From d044f460ff26e1f0e88ca06eace17390f13744c7 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Thu, 19 Sep 2019 14:15:33 +0000
Subject: [PATCH 1/5] Add tfio.IOTensor.from_avro support

Avro is a row-oriented file format whose records map naturally to
table/column data. An Avro file itself is not directly indexable.
However, it is pseudo-indexable, as it consists of blocks, with each
block specifying its file offset/size and item count, so indexing
could be done by small-range iteration within a block. It would be
desirable to make Avro indexable, as that adds convenience and
flexibility. This PR adds tfio.IOTensor.from_avro support so that it
is possible to access Avro data through natural __getitem__
operations.

Signed-off-by: Yong Tang
---
 tensorflow_io/avro/kernels/avro_kernels.cc | 231 ++++++++++++++++++
 tensorflow_io/avro/ops/avro_ops.cc | 42 ++++
 tensorflow_io/avro/python/ops/avro_ops.py | 8 +
 .../core/python/ops/avro_io_tensor_ops.py | 53 ++++
 tensorflow_io/core/python/ops/io_tensor.py | 20 ++
 tests/test_avro_eager.py | 50 ++--
 6 files changed, 376 insertions(+), 28 deletions(-)
 create mode 100644 tensorflow_io/core/python/ops/avro_io_tensor_ops.py

diff --git a/tensorflow_io/avro/kernels/avro_kernels.cc b/tensorflow_io/avro/kernels/avro_kernels.cc
index 5340cf91d..a9d2d60e0 100644
--- a/tensorflow_io/avro/kernels/avro_kernels.cc
+++ b/tensorflow_io/avro/kernels/avro_kernels.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow_io/core/kernels/stream.h"
+#include "tensorflow_io/core/kernels/io_interface.h"
 #include "api/DataFile.hh"
 #include "api/Compiler.hh"
 #include "api/Generic.hh"
@@ -287,6 +288,236 @@ REGISTER_KERNEL_BUILDER(Name("ReadAvro").Device(DEVICE_CPU),
                         ReadAvroOp);
+
 }  // namespace
+
+class AvroIndexable : public IOIndexableInterface {
+ public:
+  AvroIndexable(Env* env)
+  : env_(env) {}
+
+  ~AvroIndexable() {}
+  Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
+    if (input.size() > 1) {
+      return errors::InvalidArgument("more than 1 filename is not supported");
+    }
+    const string& filename = input[0];
+    file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
+    TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
+
+    string schema;
+    for (size_t i = 0; i < metadata.size(); i++) {
+      if (metadata[i].find("schema: ") == 0) {
+        schema = metadata[i].substr(8);
+      }
+    }
+
+    string error;
+    std::istringstream ss(schema);
+    if (!(avro::compileJsonSchema(ss, reader_schema_, error))) {
+      return errors::Internal("Avro schema error: ", error);
+    }
+
+    for (int i = 0; i < reader_schema_.root()->names(); i++) {
+      columns_.push_back(reader_schema_.root()->nameAt(i));
+      columns_index_[reader_schema_.root()->nameAt(i)] = i;
+    }
+
+    avro::GenericDatum datum(reader_schema_.root());
+    const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
+    for (size_t i = 0; i < reader_schema_.root()->names(); i++) {
+      const avro::GenericDatum& field = record.field(columns_[i]);
+      ::tensorflow::DataType dtype;
+      switch(field.type()) {
+      case avro::AVRO_BOOL:
+        dtype = DT_BOOL;
+        break;
+      case avro::AVRO_INT:
+        dtype = DT_INT32;
+        break;
+      case avro::AVRO_LONG:
+        dtype = DT_INT64;
+        break;
+      case avro::AVRO_FLOAT:
+        dtype = DT_FLOAT;
+        break;
+      case avro::AVRO_DOUBLE:
+        dtype = DT_DOUBLE;
+        break;
+      case avro::AVRO_STRING:
+        dtype = DT_STRING;
+        break;
+      case avro::AVRO_BYTES:
+        dtype = DT_STRING;
+        break;
+      case avro::AVRO_FIXED:
+        dtype = DT_STRING;
+        break;
+      case avro::AVRO_ENUM:
+        dtype = DT_STRING;
+        break;
+      default:
+        return errors::InvalidArgument("Avro type unsupported: ", field.type());
+      }
+      dtypes_.emplace_back(dtype);
+    }
+
+    // Find out the total number of rows
+    reader_stream_.reset(new AvroInputStream(file_.get()));
+    reader_.reset(new avro::DataFileReader<avro::GenericDatum>(std::move(reader_stream_), reader_schema_));
+
+    avro::DecoderPtr decoder = avro::binaryDecoder();
+
+    int64 total = 0;
+
+    reader_->sync(0);
+    int64 offset = reader_->previousSync();
+    while (offset < file_size_) {
+      StringPiece result;
+      string buffer(16, 0x00);
+      TF_RETURN_IF_ERROR(file_->Read(offset, buffer.size(), &result, &buffer[0]));
+      std::unique_ptr<avro::InputStream> in = avro::memoryInputStream((const uint8_t*)result.data(), result.size());
+      decoder->init(*in);
+      long items = decoder->decodeLong();
+
+      total += static_cast<int64>(items);
+      positions_.emplace_back(std::pair<int64, int64>(static_cast<int64>(items), offset));
+
+      reader_->sync(offset);
+      offset = reader_->previousSync();
+    }
+
+    for (size_t i = 0; i < columns_.size(); i++) {
+      shapes_.emplace_back(TensorShape({total}));
+    }
+    return Status::OK();
+  }
+  Status Components(Tensor* components) override {
+    *components = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
+    for (size_t i = 0; i < columns_.size(); i++) {
+      components->flat<string>()(i) = columns_[i];
+    }
+    return Status::OK();
+  }
+  Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) override {
+    if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) {
+      return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid");
+    }
+    int64 column_index = columns_index_[component.scalar<string>()()];
+    *shape = shapes_[column_index];
+    *dtype = dtypes_[column_index];
+    return Status::OK();
+  }
+
+  Status Read(const int64 start, const int64 stop, const Tensor& component, Tensor* value, Tensor* label) override {
+    const string& column = component.scalar<string>()();
+    avro::GenericDatum datum(reader_schema_);
+
+    // Find the start sync point
+    int64 item_index_sync = 0;
+    for (size_t i = 0; i < positions_.size(); item_index_sync += positions_[i].first, i++) {
+      if (item_index_sync >= stop) {
+        continue;
+      }
+      if (item_index_sync + positions_[i].first <= start) {
+        continue;
+      }
+      // TODO: Avro is sync-point partitioned and each block is very similar to
+      // a Row Group in Parquet. Ideally each block should be cached, with the hope
+      // that slicing and indexing will happen around the same block across multiple
+      // rows. Caching is not done yet.
+
+      // Seek to sync
+      reader_->seek(positions_[i].second);
+      for (int64 item_index = item_index_sync; item_index < (item_index_sync + positions_[i].first) && item_index < stop; item_index++) {
+        // Read anyway
+        if (!reader_->read(datum)) {
+          return errors::Internal("unable to read record at: ", item_index);
+        }
+        // Assign only when in range
+        if (item_index >= start) {
+          const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
+          const avro::GenericDatum& field = record.field(column);
+          switch(field.type()) {
+          case avro::AVRO_BOOL:
+            value->flat<bool>()(item_index - start) = field.value<bool>();
+            break;
+          case avro::AVRO_INT:
+            value->flat<int32>()(item_index - start) = field.value<int32_t>();
+            break;
+          case avro::AVRO_LONG:
+            value->flat<int64>()(item_index - start) = field.value<int64_t>();
+            break;
+          case avro::AVRO_FLOAT:
+            value->flat<float>()(item_index - start) = field.value<float>();
+            break;
+          case avro::AVRO_DOUBLE:
+            value->flat<double>()(item_index - start) = field.value<double>();
+            break;
+          case avro::AVRO_STRING:
+            value->flat<string>()(item_index - start) = field.value<string>();
+            break;
+          case avro::AVRO_BYTES: {
+              const std::vector<uint8_t>& field_value = field.value<std::vector<uint8_t>>();
+              value->flat<string>()(item_index - start) = string((char *)&field_value[0], field_value.size());
+            }
+            break;
+          case avro::AVRO_FIXED: {
+              const std::vector<uint8_t>& field_value = field.value<avro::GenericFixed>().value();
+              value->flat<string>()(item_index - start) = string((char *)&field_value[0], field_value.size());
+            }
+            break;
+          case avro::AVRO_ENUM:
+            value->flat<string>()(item_index - start) = field.value<avro::GenericEnum>().symbol();
+            break;
+          default:
+            return errors::InvalidArgument("unsupported data type: ", field.type());
+          }
+        }
+      }
+    }
+    return Status::OK();
+  }
+
+  Status Capacity(std::vector<int64>* start, std::vector<int64>* stop) {
+    start->clear();
+    stop->clear();
+    int64 item_index = 0;
+    // positions_ are pairs of <items, offset>
+    for (size_t i = 0; i < positions_.size(); i++) {
+      start->emplace_back(item_index);
+      item_index += positions_[i].first;
+      stop->emplace_back(item_index);
+    }
+    return Status::OK();
+  }
+
+  string DebugString() const override {
+    mutex_lock l(mu_);
+    return strings::StrCat("AvroIndexable");
+  }
+ private:
+  mutable mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+  std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
+  uint64 file_size_ GUARDED_BY(mu_);
+  avro::ValidSchema reader_schema_;
+  std::unique_ptr<AvroInputStream> reader_stream_;
+  std::unique_ptr<avro::DataFileReader<avro::GenericDatum>> reader_;
+  std::vector<std::pair<int64, int64>> positions_;  // pair of <items, offset>
+
+  std::vector<DataType> dtypes_;
+  std::vector<TensorShape> shapes_;
+  std::vector<string> columns_;
+  std::unordered_map<string, int64> columns_index_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("AvroIndexableInit").Device(DEVICE_CPU),
+                        IOInterfaceInitOp<AvroIndexable>);
+REGISTER_KERNEL_BUILDER(Name("AvroIndexableSpec").Device(DEVICE_CPU),
+                        IOInterfaceSpecOp<AvroIndexable>);
+REGISTER_KERNEL_BUILDER(Name("AvroIndexableRead").Device(DEVICE_CPU),
+                        IOIndexableReadOp<AvroIndexable>);
+
 }  // namespace data
 }  // namespace tensorflow

diff --git a/tensorflow_io/avro/ops/avro_ops.cc b/tensorflow_io/avro/ops/avro_ops.cc
index 1c090dd83..8a5899eac 100644
--- a/tensorflow_io/avro/ops/avro_ops.cc
+++ b/tensorflow_io/avro/ops/avro_ops.cc
@@ -45,4 +45,46 @@ REGISTER_OP("ReadAvro")
     return Status::OK();
   });
 
+REGISTER_OP("AvroIndexableInit")
+  .Input("input: string")
+  .Input("metadata: string")
+  .Output("resource: resource")
+  .Output("component: string")
+  .Attr("container: string = ''")
+  .Attr("shared_name: string = ''")
+  .SetShapeFn([](shape_inference::InferenceContext* c) {
+    c->set_output(0, c->Scalar());
+    c->set_output(1, c->MakeShape({}));
+    return Status::OK();
+  });
+
+REGISTER_OP("AvroIndexableSpec")
+  .Input("input: resource")
+  .Input("component: string")
+  .Output("shape: int64")
+  .Output("dtype: int64")
+  .SetShapeFn([](shape_inference::InferenceContext* c) {
+    c->set_output(0, c->MakeShape({c->UnknownDim()}));
+    c->set_output(1, c->MakeShape({}));
+    return Status::OK();
+  });
+
+REGISTER_OP("AvroIndexableRead")
+  .Input("input: resource")
+  .Input("start: int64")
+  .Input("stop: int64")
+  .Input("component: string")
+  .Output("value: dtype")
+  .Attr("filter: list(string) = []")
+  .Attr("shape: shape")
+  .Attr("dtype: type")
+  .SetShapeFn([](shape_inference::InferenceContext* c) {
+    PartialTensorShape shape;
+    TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+    shape_inference::ShapeHandle entry;
+    TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
+    c->set_output(0, entry);
+    return Status::OK();
+  });
+
 }  // namespace tensorflow

diff --git a/tensorflow_io/avro/python/ops/avro_ops.py b/tensorflow_io/avro/python/ops/avro_ops.py
index 9a5278210..711d66ecb 100644
--- a/tensorflow_io/avro/python/ops/avro_ops.py
+++ b/tensorflow_io/avro/python/ops/avro_ops.py
@@ -17,10 +17,18 @@
 from __future__ import division
 from __future__ import print_function
 
+import warnings
+
 import tensorflow as tf
 from tensorflow_io.core.python.ops import core_ops
 from tensorflow_io.core.python.ops import data_ops
 
+warnings.warn(
+    "The tensorflow_io.avro.AvroDataset is "
+    "deprecated. Please use tfio.IOTensor.from_avro "
+    "to read Avro files into TensorFlow.",
+    DeprecationWarning)
+
 def list_avro_columns(filename, schema, **kwargs):
   """list_avro_columns"""
   if not tf.executing_eagerly():

diff --git a/tensorflow_io/core/python/ops/avro_io_tensor_ops.py b/tensorflow_io/core/python/ops/avro_io_tensor_ops.py
new file mode 100644
index 000000000..59c11b621
--- /dev/null
+++ b/tensorflow_io/core/python/ops/avro_io_tensor_ops.py
@@ -0,0 +1,53 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AvroIOTensor"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import uuid
+
+import tensorflow as tf
+from tensorflow_io.core.python.ops import io_tensor_ops
+from tensorflow_io.core.python.ops import core_ops
+
+class AvroIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access
+  """AvroIOTensor"""
+
+  #=============================================================================
+  # Constructor (private)
+  #=============================================================================
+  def __init__(self,
+               filename,
+               schema,
+               internal=False):
+    with tf.name_scope("AvroIOTensor") as scope:
+      metadata = ["schema: %s" % schema]
+      resource, columns = core_ops.avro_indexable_init(
+          filename, metadata=metadata,
+          container=scope,
+          shared_name="%s/%s" % (filename, uuid.uuid4().hex))
+      columns = [column.decode() for column in columns.numpy().tolist()]
+      spec = []
+      for column in columns:
+        shape, dtype = core_ops.avro_indexable_spec(resource, column)
+        shape = tf.TensorShape(shape)
+        dtype = tf.as_dtype(dtype.numpy())
+        spec.append(tf.TensorSpec(shape, dtype, column))
+      spec = tuple(spec)
+      super(AvroIOTensor, self).__init__(
+          spec, columns,
+          resource, core_ops.avro_indexable_read,
+          internal=internal)

diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py
index e86cf60c4..e74bc037c 100644
--- a/tensorflow_io/core/python/ops/io_tensor.py
+++ b/tensorflow_io/core/python/ops/io_tensor.py
@@ -27,6 +27,7 @@
 from tensorflow_io.core.python.ops import prometheus_io_tensor_ops
 from tensorflow_io.core.python.ops import feather_io_tensor_ops
 from tensorflow_io.core.python.ops import csv_io_tensor_ops
+from tensorflow_io.core.python.ops import avro_io_tensor_ops
 
 class IOTensor(io_tensor_ops._IOTensor): # pylint: disable=protected-access
   """IOTensor
@@ -382,3 +383,22 @@ def from_csv(cls,
     """
     with tf.name_scope(kwargs.get("name", "IOFromCSV")):
       return csv_io_tensor_ops.CSVIOTensor(filename, internal=True)
+
+  @classmethod
+  def from_avro(cls,
+                filename,
+                schema,
+                **kwargs):
+    """Creates an `IOTensor` from an Avro file.
+
+    Args:
+      filename: A string, the filename of an Avro file.
+      schema: A string, the schema of an Avro file.
+      name: A name prefix for the IOTensor (optional).
+
+    Returns:
+      An `IOTensor`.
+
+    """
+    with tf.name_scope(kwargs.get("name", "IOFromAvro")):
+      return avro_io_tensor_ops.AvroIOTensor(filename, schema, internal=True)

diff --git a/tests/test_avro_eager.py b/tests/test_avro_eager.py
index 172697239..ad99160ed 100644
--- a/tests/test_avro_eager.py
+++ b/tests/test_avro_eager.py
@@ -12,21 +12,22 @@
 # License for the specific language governing permissions and limitations under
 # the License.
 # ==============================================================================
-"""Tests for AvroDataset."""
+"""Tests for tfio.IOTensor.from_avro."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import os
+import numpy as np
 import tensorflow as tf
 if not (hasattr(tf, "version") and
         tf.version.VERSION.startswith("2.")):
   tf.compat.v1.enable_eager_execution()
-import tensorflow_io.avro as avro_io # pylint: disable=wrong-import-position
+import tensorflow_io as tfio # pylint: disable=wrong-import-position
 
 def test_avro():
-  """test_list_avro_columns."""
+  """test_avro"""
   # The test.bin was created from avro/lang/c++/examples/datafile.cc.
   filename = os.path.join(
       os.path.dirname(os.path.abspath(__file__)),
       "test_avro", "test.bin")
@@ -39,31 +40,24 @@
   with open(schema_filename, 'r') as f:
     schema = f.read()
 
-  specs = avro_io.list_avro_columns(filename, schema)
-  assert specs["im"].dtype == tf.float64
-  assert specs["re"].dtype == tf.float64
-
-  v0 = avro_io.read_avro(filename, schema, specs["im"])
-  v1 = avro_io.read_avro(filename, schema, specs["re"])
-  for i in range(100):
-    (im, re) = (i + 100, i * 100)
-    assert v0[i].numpy() == im
-    assert v1[i].numpy() == re
-
-  for capacity in [10, 20, 50, 100, 1000, 2000]:
-    dataset = tf.compat.v2.data.Dataset.zip(
-        (
-            avro_io.AvroDataset(filename, schema, "im", capacity=capacity),
-            avro_io.AvroDataset(filename, schema, "re", capacity=capacity)
-        )
-    ).apply(tf.data.experimental.unbatch())
-    i = 0
-    for vv in dataset:
-      v0, v1 = vv
-      (im, re) = (i + 100, i * 100)
-      assert v0.numpy() == im
-      assert v1.numpy() == re
-      i += 1
+  avro = tfio.IOTensor.from_avro(filename, schema)
+  assert avro("im").dtype == tf.float64
+  assert avro("im").shape == [100]
+  assert avro("re").dtype == tf.float64
+  assert avro("re").shape == [100]
+
+  assert np.all(
+      avro("im").to_tensor().numpy() == [100.0 + i for i in range(100)])
+  assert np.all(
+      avro("re").to_tensor().numpy() == [100.0 * i for i in range(100)])
+
+  dataset = avro.to_dataset()
+  i = 0
+  for v in dataset:
+    re, im = v
+    assert im.numpy() == 100.0 + i
+    assert re.numpy() == 100.0 * i
+    i += 1
 
 if __name__ == "__main__":
   test.main()

From ddde2b04eb120f5fea37123e4d70142077cc4e9e Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Thu, 19 Sep 2019 17:27:04 +0000
Subject: [PATCH 2/5] Add a Partitions function to Avro, so that the chunk
 capacity can be adjusted dynamically when reading

Signed-off-by: Yong Tang
---
 tensorflow_io/avro/kernels/avro_kernels.cc | 25 ++++++------
 tensorflow_io/avro/ops/avro_ops.cc | 8 ++++
 tensorflow_io/core/kernels/io_interface.h | 29 +++++++++++++-
 .../core/python/ops/avro_io_tensor_ops.py | 11 +++++-
 .../core/python/ops/csv_io_tensor_ops.py | 1 +
 .../core/python/ops/feather_io_tensor_ops.py | 1 +
 .../core/python/ops/io_tensor_ops.py | 35 ++++++++++++++++-
 .../core/python/ops/json_io_tensor_ops.py | 1 +
 tests/test_avro_eager.py | 38 +++++++++++++++++++
 9 files changed, 133 insertions(+), 16 deletions(-)

diff --git a/tensorflow_io/avro/kernels/avro_kernels.cc b/tensorflow_io/avro/kernels/avro_kernels.cc
index a9d2d60e0..d26d724bc 100644
--- a/tensorflow_io/avro/kernels/avro_kernels.cc
+++ b/tensorflow_io/avro/kernels/avro_kernels.cc
@@ -392,6 +392,16 @@ class AvroIndexable : public IOIndexableInterface {
     }
     return Status::OK();
   }
+
+  Status Partitions(std::vector<int64>* partitions) override {
+    partitions->clear();
+    // positions_ are pairs of <items, offset>
+    for (size_t i = 0; i < positions_.size(); i++) {
+      partitions->emplace_back(positions_[i].first);
+    }
+    return Status::OK();
+  }
+
   Status Components(Tensor* components) override {
     *components = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
     for (size_t i = 0; i < columns_.size(); i++) {
@@ -479,19 +489,6 @@ class AvroIndexable : public IOIndexableInterface {
     return Status::OK();
   }
-
-  Status Capacity(std::vector<int64>* start, std::vector<int64>* stop) {
-    start->clear();
-    stop->clear();
-    int64 item_index = 0;
-    // positions_ are pairs of <items, offset>
-    for (size_t i = 0; i < positions_.size(); i++) {
-      start->emplace_back(item_index);
-      item_index += positions_[i].first;
-      stop->emplace_back(item_index);
-    }
-    return Status::OK();
-  }
   string DebugString() const override {
     mutex_lock l(mu_);
     return strings::StrCat("AvroIndexable");
   }
@@ -516,6 +513,8 @@ REGISTER_KERNEL_BUILDER(Name("AvroIndexableInit").Device(DEVICE_CPU),
                         IOInterfaceInitOp<AvroIndexable>);
 REGISTER_KERNEL_BUILDER(Name("AvroIndexableSpec").Device(DEVICE_CPU),
                         IOInterfaceSpecOp<AvroIndexable>);
+REGISTER_KERNEL_BUILDER(Name("AvroIndexablePartitions").Device(DEVICE_CPU),
+                        IOIndexablePartitionsOp<AvroIndexable>);
 REGISTER_KERNEL_BUILDER(Name("AvroIndexableRead").Device(DEVICE_CPU),
                         IOIndexableReadOp<AvroIndexable>);

diff --git a/tensorflow_io/avro/ops/avro_ops.cc b/tensorflow_io/avro/ops/avro_ops.cc
index 8a5899eac..4d5f0f22f 100644
--- a/tensorflow_io/avro/ops/avro_ops.cc
+++ b/tensorflow_io/avro/ops/avro_ops.cc
@@ -87,4 +87,12 @@ REGISTER_OP("AvroIndexableRead")
     return Status::OK();
   });
 
+REGISTER_OP("AvroIndexablePartitions")
+  .Input("input: resource")
+  .Output("partitions: int64")
+  .SetShapeFn([](shape_inference::InferenceContext* c) {
+    c->set_output(0, c->MakeShape({c->UnknownDim()}));
+    return Status::OK();
+  });
+
 }  // namespace tensorflow

diff --git a/tensorflow_io/core/kernels/io_interface.h b/tensorflow_io/core/kernels/io_interface.h
index 83a112328..b627b62a8 100644
--- a/tensorflow_io/core/kernels/io_interface.h
+++ b/tensorflow_io/core/kernels/io_interface.h
@@ -25,9 +25,13 @@ class IOInterface : public ResourceBase {
   virtual Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) = 0;
   virtual Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) = 0;
+  virtual Status Partitions(std::vector<int64>* partitions) {
+    // By default Partitions is not implemented: Unimplemented
+    return errors::Unimplemented("Partitions");
+  }
   virtual Status Components(Tensor* components) {
     // By default there is only one component: Unimplemented
-    return errors::Unimplemented("Component");
+    return errors::Unimplemented("Components");
   }
   virtual Status Extra(const Tensor& component, std::vector<Tensor>* extra) {
     // This is the chance to provide additional extra information which should be appended to extra.
@@ -451,5 +455,28 @@ class IOMappingReadOp : public OpKernel {
     context->set_output(0, value);
   }
 };
+template <typename Type>
+class IOIndexablePartitionsOp : public OpKernel {
+ public:
+  explicit IOIndexablePartitionsOp(OpKernelConstruction* ctx)
+  : OpKernel(ctx) {
+  }
+
+  void Compute(OpKernelContext* context) override {
+    Type* resource;
+    OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource));
+    core::ScopedUnref unref(resource);
+
+    std::vector<int64> partitions;
+    OP_REQUIRES_OK(context, resource->Partitions(&partitions));
+
+    Tensor partitions_tensor(DT_INT64, TensorShape({static_cast<int64>(partitions.size())}));
+    for (size_t i = 0; i < partitions.size(); i++) {
+      partitions_tensor.flat<int64>()(i) = partitions[i];
+    }
+
+    context->set_output(0, partitions_tensor);
+  }
+};
 }  // namespace data
 }  // namespace tensorflow

diff --git a/tensorflow_io/core/python/ops/avro_io_tensor_ops.py b/tensorflow_io/core/python/ops/avro_io_tensor_ops.py
index 59c11b621..0c5eaa418 100644
--- a/tensorflow_io/core/python/ops/avro_io_tensor_ops.py
+++ b/tensorflow_io/core/python/ops/avro_io_tensor_ops.py
@@ -32,6 +32,7 @@ class AvroIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access
   def __init__(self,
                filename,
                schema,
+               capacity=None,
                internal=False):
     with tf.name_scope("AvroIOTensor") as scope:
       metadata = ["schema: %s" % schema]
@@ -39,6 +40,14 @@ def __init__(self,
           filename, metadata=metadata,
           container=scope,
           shared_name="%s/%s" % (filename, uuid.uuid4().hex))
+      partitions = None
+      if capacity is not None:
+        partitions = core_ops.avro_indexable_partitions(resource)
+        partitions = partitions.numpy().tolist()
+        if capacity > 0:
+          partitions = [
+              v for e in partitions for v in list(
+                  [capacity] * (e // capacity) + [e % capacity])]
       columns = [column.decode() for column in columns.numpy().tolist()]
       spec = []
       for column in columns:
@@ -50,4 +59,4 @@
       super(AvroIOTensor, self).__init__(
           spec, columns,
           resource, core_ops.avro_indexable_read,
-          internal=internal)
+          partitions=partitions, internal=internal)

diff --git a/tensorflow_io/core/python/ops/csv_io_tensor_ops.py b/tensorflow_io/core/python/ops/csv_io_tensor_ops.py
index 8dad814b3..5dac6be3e 100644
--- a/tensorflow_io/core/python/ops/csv_io_tensor_ops.py
+++ b/tensorflow_io/core/python/ops/csv_io_tensor_ops.py
@@ -48,6 +48,7 @@ def __init__(self,
     super(CSVIOTensor, self).__init__(
         spec, columns,
         resource, core_ops.csv_indexable_read,
+        partitions=None,
        internal=internal)
 
   #=============================================================================

diff --git a/tensorflow_io/core/python/ops/feather_io_tensor_ops.py b/tensorflow_io/core/python/ops/feather_io_tensor_ops.py
index 18f281598..6bf72bbbb 100644
--- a/tensorflow_io/core/python/ops/feather_io_tensor_ops.py
+++ b/tensorflow_io/core/python/ops/feather_io_tensor_ops.py
@@ -48,4 +48,5 @@ def __init__(self,
     super(FeatherIOTensor, self).__init__(
         spec, columns,
         resource, core_ops.feather_indexable_read,
+        partitions=None,
        internal=internal)

diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py
index d084e9003..da58d04a3 100644
--- a/tensorflow_io/core/python/ops/io_tensor_ops.py
+++ b/tensorflow_io/core/python/ops/io_tensor_ops.py
@@ 
-277,10 +277,12 @@ def __init__(self, columns, resource, function, + partitions, internal=False): self._columns = columns self._resource = resource self._function = function + self._partitions = partitions super(_TableIOTensor, self).__init__( spec, internal=internal) @@ -310,9 +312,40 @@ def __call__(self, resource, start, stop): resource, start=start, stop=stop, component=self._component, shape=self._shape, dtype=self._dtype) + class _PartitionedFunction(object): + """PartitionedFunction will translate call to cached Function call""" + def __init__(self, func, partitions): + self._func = func + self._partitions = partitions + partitions_indices = tf.cumsum(partitions).numpy().tolist() + self._partitions_start = list([0] + partitions_indices[:-1]) + self._partitions_stop = partitions_indices + self._tensors = [None for _ in partitions] + def __call__(self, resource, start, stop): + indices_start = tf.math.maximum(self._partitions_start, start) + indices_stop = tf.math.minimum(self._partitions_stop, stop) + indices_hit = tf.math.less(indices_start, indices_stop) + indices = tf.squeeze(tf.compat.v2.where(indices_hit), [1]) + items = [] + # TODO: change to tf.while_loop + for index in indices: + if self._tensors[index] is None: + self._tensors[index] = self._func( + resource, + self._partitions_start[index], + self._partitions_stop[index]) + slice_start = indices_start[index] - self._partitions_start[index] + slice_size = indices_stop[index] - indices_start[index] + item = tf.slice(self._tensors[index], [slice_start], [slice_size]) + items.append(item) + return tf.concat(items, axis=0) + + function = _Function(self._function, spec, column) + if self._partitions is not None: + function = _PartitionedFunction(function, self._partitions) return BaseIOTensor( - spec, self._resource, _Function(self._function, spec, column), + spec, self._resource, function, internal=True) #============================================================================= diff --git a/tensorflow_io/core/python/ops/json_io_tensor_ops.py b/tensorflow_io/core/python/ops/json_io_tensor_ops.py index 64ab43e4c..6bea4a21b 100644 --- a/tensorflow_io/core/python/ops/json_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/json_io_tensor_ops.py @@ -51,4 +51,5 @@ def __init__(self, super(JSONIOTensor, self).__init__( spec, columns, resource, core_ops.json_indexable_read, + partitions=None, internal=internal) diff --git a/tests/test_avro_eager.py b/tests/test_avro_eager.py index ad99160ed..fd0997250 100644 --- a/tests/test_avro_eager.py +++ b/tests/test_avro_eager.py @@ -25,6 +25,7 @@ if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() import tensorflow_io as tfio # pylint: disable=wrong-import-position +from tensorflow_io.core.python.ops import avro_io_tensor_ops # pylint: disable=wrong-import-position def test_avro(): """test_avro""" @@ -59,5 +60,42 @@ def test_avro(): assert re.numpy() == 100.0 * i i += 1 +def test_avro_partition(): + """test_avro_partition""" + # The test.bin was created from avro/lang/c++/examples/datafile.cc. 
+ filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_avro", "test.bin") + filename = "file://" + filename + + schema_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_avro", "cpx.json") + with open(schema_filename, 'r') as f: + schema = f.read() + for capacity in [ + 1, 2, 3, + 11, 12, 13, + 50, 51, 100]: + avro = avro_io_tensor_ops.AvroIOTensor( + filename, schema, capacity=capacity, internal=True) + assert np.all( + avro("im").to_tensor().numpy() == [100.0 + i for i in range(100)]) + assert np.all( + avro("re").to_tensor().numpy() == [100.0 * i for i in range(100)]) + for step in [ + 1, 2, 3, + 10, 11, 12, 13, + 50, 51, 52, 53]: + indices = range(0, 100, step) + for start, stop in zip(indices, indices[1:] + [100]): + im_expected = [100.0 + i for i in range(start, stop)] + im_items = avro("im")[start:stop] + assert np.all(im_items.numpy() == im_expected) + + re_expected = [100.0 * i for i in range(start, stop)] + re_items = avro("re")[start:stop] + assert np.all(re_items.numpy() == re_expected) + if __name__ == "__main__": test.main() From f23347fe7022593fae7f06e6b8f3ff800468a045 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 19 Sep 2019 17:38:07 +0000 Subject: [PATCH 3/5] Rename to io_stream.h for consistency Signed-off-by: Yong Tang --- tensorflow_io/arrow/kernels/arrow_dataset_ops.cc | 2 +- tensorflow_io/arrow/kernels/arrow_kernels.h | 2 +- tensorflow_io/audio/kernels/audio_kernels.cc | 2 +- tensorflow_io/avro/kernels/avro_kernels.cc | 2 +- tensorflow_io/core/BUILD | 2 +- tensorflow_io/core/kernels/archive_kernels.cc | 2 +- tensorflow_io/core/kernels/{stream.h => io_stream.h} | 0 tensorflow_io/hdf5/kernels/hdf5_kernels.cc | 2 +- tensorflow_io/json/kernels/json_kernels.cc | 3 +-- tensorflow_io/lmdb/kernels/lmdb_kernels.cc | 2 +- tensorflow_io/text/kernels/csv_kernels.cc | 3 +-- tensorflow_io/text/kernels/text_kernels.cc | 2 +- 12 files changed, 11 insertions(+), 13 deletions(-) rename tensorflow_io/core/kernels/{stream.h => io_stream.h} (100%) diff --git a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc index cbe0c9a4e..c9811edc9 100644 --- a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc +++ b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "arrow/util/io-util.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "tensorflow_io/arrow/kernels/arrow_kernels.h" #include "tensorflow_io/arrow/kernels/arrow_stream_client.h" #include "tensorflow_io/arrow/kernels/arrow_util.h" diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.h b/tensorflow_io/arrow/kernels/arrow_kernels.h index 39abef014..6792b092a 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.h +++ b/tensorflow_io/arrow/kernels/arrow_kernels.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_IO_ARROW_KERNELS_H_ #define TENSORFLOW_IO_ARROW_KERNELS_H_ -#include "kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "arrow/io/api.h" #include "arrow/buffer.h" #include "arrow/type.h" diff --git a/tensorflow_io/audio/kernels/audio_kernels.cc b/tensorflow_io/audio/kernels/audio_kernels.cc index 8d3c1e4d0..5d33b2387 100644 --- a/tensorflow_io/audio/kernels/audio_kernels.cc +++ b/tensorflow_io/audio/kernels/audio_kernels.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow_io/core/kernels/io_interface.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" namespace tensorflow { namespace data { diff --git a/tensorflow_io/avro/kernels/avro_kernels.cc b/tensorflow_io/avro/kernels/avro_kernels.cc index d26d724bc..20c30359d 100644 --- a/tensorflow_io/avro/kernels/avro_kernels.cc +++ b/tensorflow_io/avro/kernels/avro_kernels.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/kernels/stream.h" #include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "api/DataFile.hh" #include "api/Compiler.hh" #include "api/Generic.hh" diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index b3a785306..2cc14f179 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -30,7 +30,7 @@ cc_library( srcs = [ "kernels/dataset_ops.h", "kernels/io_interface.h", - "kernels/stream.h", + "kernels/io_stream.h", ], copts = tf_io_copts(), includes = [ diff --git a/tensorflow_io/core/kernels/archive_kernels.cc b/tensorflow_io/core/kernels/archive_kernels.cc index 2149687b9..5b7358fd7 100644 --- a/tensorflow_io/core/kernels/archive_kernels.cc +++ b/tensorflow_io/core/kernels/archive_kernels.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_inputstream.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" namespace tensorflow { namespace data { diff --git a/tensorflow_io/core/kernels/stream.h b/tensorflow_io/core/kernels/io_stream.h similarity index 100% rename from tensorflow_io/core/kernels/stream.h rename to tensorflow_io/core/kernels/io_stream.h diff --git a/tensorflow_io/hdf5/kernels/hdf5_kernels.cc b/tensorflow_io/hdf5/kernels/hdf5_kernels.cc index 0673b2440..2e02f3654 100644 --- a/tensorflow_io/hdf5/kernels/hdf5_kernels.cc +++ b/tensorflow_io/hdf5/kernels/hdf5_kernels.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow_io/core/kernels/io_interface.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include #include diff --git a/tensorflow_io/json/kernels/json_kernels.cc b/tensorflow_io/json/kernels/json_kernels.cc index 98bca657b..3ac3f00fe 100644 --- a/tensorflow_io/json/kernels/json_kernels.cc +++ b/tensorflow_io/json/kernels/json_kernels.cc @@ -17,12 +17,11 @@ limitations under the License. 
#include #include "kernels/dataset_ops.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/kernels/stream.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/platform/env.h" #include "include/json/json.h" #include "tensorflow_io/core/kernels/io_interface.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "arrow/memory_pool.h" #include "arrow/json/reader.h" #include "arrow/table.h" diff --git a/tensorflow_io/lmdb/kernels/lmdb_kernels.cc b/tensorflow_io/lmdb/kernels/lmdb_kernels.cc index bd6f85a79..966ef7228 100644 --- a/tensorflow_io/lmdb/kernels/lmdb_kernels.cc +++ b/tensorflow_io/lmdb/kernels/lmdb_kernels.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/kernels/stream.h" #include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include #include #include "lmdb.h" diff --git a/tensorflow_io/text/kernels/csv_kernels.cc b/tensorflow_io/text/kernels/csv_kernels.cc index 86ee70a57..c2318af1b 100644 --- a/tensorflow_io/text/kernels/csv_kernels.cc +++ b/tensorflow_io/text/kernels/csv_kernels.cc @@ -14,10 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/kernels/stream.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow_io/core/kernels/io_interface.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "arrow/memory_pool.h" #include "arrow/csv/reader.h" #include "arrow/table.h" diff --git a/tensorflow_io/text/kernels/text_kernels.cc b/tensorflow_io/text/kernels/text_kernels.cc index 0481237a9..8f99fc468 100644 --- a/tensorflow_io/text/kernels/text_kernels.cc +++ b/tensorflow_io/text/kernels/text_kernels.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow_io/core/kernels/io_stream.h"
 #include "tensorflow/core/lib/io/buffered_inputstream.h"
-#include "tensorflow_io/core/kernels/stream.h"
 
 namespace tensorflow {
 namespace data {

From 20063ad6559f5964b75375aa0a330ba0c1f096af Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Thu, 19 Sep 2019 17:53:42 +0000
Subject: [PATCH 4/5] Remove the need to pass component, unless explicitly
 needed

Signed-off-by: Yong Tang
---
 tensorflow_io/audio/ops/audio_ops.cc | 2 --
 tensorflow_io/core/kernels/io_interface.h | 29 ++++++++++++++++---
 .../core/python/ops/audio_io_tensor_ops.py | 6 ++--
 .../core/python/ops/kafka_dataset_ops.py | 2 +-
 .../core/python/ops/kafka_io_tensor_ops.py | 1 -
 .../core/python/ops/lmdb_io_tensor_ops.py | 2 +-
 tensorflow_io/kafka/ops/kafka_ops.cc | 3 --
 tensorflow_io/lmdb/ops/lmdb_ops.cc | 1 -
 8 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/tensorflow_io/audio/ops/audio_ops.cc b/tensorflow_io/audio/ops/audio_ops.cc
index 420f5882a..f194faa03 100644
--- a/tensorflow_io/audio/ops/audio_ops.cc
+++ b/tensorflow_io/audio/ops/audio_ops.cc
@@ -31,7 +31,6 @@ REGISTER_OP("WAVIndexableInit")
 
 REGISTER_OP("WAVIndexableSpec")
   .Input("input: resource")
-  .Input("component: int64")
   .Output("shape: int64")
   .Output("dtype: int64")
   .Output("rate: int32")
@@ -46,7 +45,6 @@ REGISTER_OP("WAVIndexableRead")
   .Input("input: resource")
   .Input("start: int64")
   .Input("stop: int64")
-  .Input("component: int64")
   .Output("value: dtype")
   .Attr("shape: shape")
   .Attr("dtype: type")

diff --git a/tensorflow_io/core/kernels/io_interface.h b/tensorflow_io/core/kernels/io_interface.h
index b627b62a8..fd6f1c697 100644
--- a/tensorflow_io/core/kernels/io_interface.h
+++ b/tensorflow_io/core/kernels/io_interface.h
@@ -249,8 +249,15 @@ class IOInterfaceSpecOp : public OpKernel {
     OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource));
     core::ScopedUnref unref(resource);
 
+    Status status;
+
+    Tensor component_empty(DT_INT64, TensorShape({}));
+    component_empty.scalar<int64>()() = 0;
     const Tensor* component;
-    OP_REQUIRES_OK(context, context->input("component", &component));
+    status = context->input("component", &component);
+    if (!status.ok()) {
+      component = &component_empty;
+    }
 
     PartialTensorShape shape;
     DataType dtype;
@@ -266,7 +273,7 @@ class IOInterfaceSpecOp : public OpKernel {
     context->set_output(1, dtype_tensor);
 
     std::vector<Tensor> extra;
-    Status status = resource->Extra(*component, &extra);
+    status = resource->Extra(*component, &extra);
     if (!errors::IsUnimplemented(status)) {
       OP_REQUIRES_OK(context, status);
       for (size_t i = 0; i < extra.size(); i++) {
@@ -308,8 +315,15 @@ class IOIterableNextOp : public OpKernel {
     OP_REQUIRES_OK(context, context->input("capacity", &capacity_tensor));
     const int64 capacity = capacity_tensor->scalar<int64>()();
 
+    Status status;
+
+    Tensor component_empty(DT_INT64, TensorShape({}));
+    component_empty.scalar<int64>()() = 0;
     const Tensor* component;
-    OP_REQUIRES_OK(context, context->input("component", &component));
+    status = context->input("component", &component);
+    if (!status.ok()) {
+      component = &component_empty;
+    }
 
     OP_REQUIRES(context, (capacity > 0), errors::InvalidArgument("capacity <= 0 is not supported: ", capacity));
@@ -403,8 +417,15 @@ class IOIndexableReadOp : public OpKernel {
     OP_REQUIRES_OK(context, context->input("stop", &stop_tensor));
     int64 stop = stop_tensor->scalar<int64>()();
 
+    Status status;
+
+    Tensor component_empty(DT_INT64, TensorShape({}));
+    component_empty.scalar<int64>()() = 0;
     const Tensor* component;
-    OP_REQUIRES_OK(context, context->input("component", &component));
+    status = context->input("component", &component);
+    if (!status.ok()) {
+      component = &component_empty;
+    }
 
     int64 output_index = 0;
     Tensor* value_tensor = nullptr;

diff --git a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py
index bc8a803aa..1bb92a41d 100644
--- a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py
+++ b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py
@@ -44,22 +44,20 @@ def __init__(self,
               filename,
               container=scope,
               shared_name="%s/%s" % (filename, uuid.uuid4().hex))
-    shape, dtype, rate = core_ops.wav_indexable_spec(resource, component=0)
+    shape, dtype, rate = core_ops.wav_indexable_spec(resource)
     shape = tf.TensorShape(shape.numpy())
     dtype = tf.as_dtype(dtype.numpy())
     spec = tf.TensorSpec(shape, dtype)
 
     class _Function(object):
-      def __init__(self, func, shape, dtype, component=0):
+      def __init__(self, func, shape, dtype):
        self._func = func
        self._shape = tf.TensorShape([None]).concatenate(shape[1:])
        self._dtype = dtype
-        self._component = component
       def __call__(self, resource, start, stop):
        return self._func(
            resource, start=start, stop=stop,
-            component=self._component,
            shape=self._shape, dtype=self._dtype)
 
     self._rate = rate.numpy()

diff --git a/tensorflow_io/core/python/ops/kafka_dataset_ops.py b/tensorflow_io/core/python/ops/kafka_dataset_ops.py
index c352d97ba..e413615e0 100644
--- a/tensorflow_io/core/python/ops/kafka_dataset_ops.py
+++ b/tensorflow_io/core/python/ops/kafka_dataset_ops.py
@@ -62,7 +62,7 @@ def __init__(self,
     dataset = tf.compat.v2.data.Dataset.range(0, sys.maxsize, capacity)
     dataset = dataset.map(
         lambda i: core_ops.kafka_iterable_next(
-            resource, capacity, component=0,
+            resource, capacity,
             dtype=tf.string, shape=tf.TensorShape([None])))
     dataset = dataset.apply(
         tf.data.experimental.take_while(

diff --git a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py
index 381c18817..90ff6103f 100644
--- a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py
+++ b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py
@@ -58,7 +58,6 @@ def __init__(self, func, spec):
       def __call__(self, resource, start, stop):
        return self._func(
            resource, start=start, stop=stop,
-            component=0,
            shape=self._shape, dtype=self._dtype)
 
     self._iterable = iterable

diff --git a/tensorflow_io/core/python/ops/lmdb_io_tensor_ops.py b/tensorflow_io/core/python/ops/lmdb_io_tensor_ops.py
index 34482da2e..e5044ffbf 100644
--- a/tensorflow_io/core/python/ops/lmdb_io_tensor_ops.py
+++ b/tensorflow_io/core/python/ops/lmdb_io_tensor_ops.py
@@ -64,7 +64,7 @@ def __init__(self, func, shape, dtype):
        self._dtype = dtype
       def __call__(self, resource, capacity):
        return self._func(
-            resource, capacity, component=0,
+            resource, capacity,
            shape=self._shape, dtype=self._dtype)
 
     super(LMDBIOTensor, self).__init__(

diff --git a/tensorflow_io/kafka/ops/kafka_ops.cc b/tensorflow_io/kafka/ops/kafka_ops.cc
index 6de3e420b..28d8733e0 100644
--- a/tensorflow_io/kafka/ops/kafka_ops.cc
+++ b/tensorflow_io/kafka/ops/kafka_ops.cc
@@ -33,7 +33,6 @@ REGISTER_OP("KafkaIndexableInit")
 
 REGISTER_OP("KafkaIndexableSpec")
   .Input("input: resource")
-  .Input("component: int64")
   .Output("shape: int64")
   .Output("dtype: int64")
   .SetShapeFn([](shape_inference::InferenceContext* c) {
@@ -46,7 +45,6 @@ REGISTER_OP("KafkaIndexableRead") 
.Input("input: resource") .Input("start: int64") .Input("stop: int64") - .Input("component: int64") .Output("value: dtype") .Attr("shape: shape") .Attr("dtype: type") @@ -73,7 +71,6 @@ REGISTER_OP("KafkaIterableInit") REGISTER_OP("KafkaIterableNext") .Input("input: resource") .Input("capacity: int64") - .Input("component: int64") .Output("value: dtype") .Attr("shape: shape") .Attr("dtype: type") diff --git a/tensorflow_io/lmdb/ops/lmdb_ops.cc b/tensorflow_io/lmdb/ops/lmdb_ops.cc index c08b36041..367384ec4 100644 --- a/tensorflow_io/lmdb/ops/lmdb_ops.cc +++ b/tensorflow_io/lmdb/ops/lmdb_ops.cc @@ -56,7 +56,6 @@ REGISTER_OP("LMDBIterableInit") REGISTER_OP("LMDBIterableNext") .Input("input: resource") .Input("capacity: int64") - .Input("component: int64") .Output("value: dtype") .Attr("shape: shape") .Attr("dtype: type") From 5759a72d179f4eb512143d398a05d6f9a10d83bc Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 19 Sep 2019 20:37:18 +0000 Subject: [PATCH 5/5] Move Partitions to a generic location and support dataset Signed-off-by: Yong Tang --- .../core/python/ops/audio_io_tensor_ops.py | 2 +- .../core/python/ops/csv_io_tensor_ops.py | 2 +- tensorflow_io/core/python/ops/io_tensor.py | 3 +- .../core/python/ops/io_tensor_ops.py | 107 ++++++++++-------- .../core/python/ops/kafka_io_tensor_ops.py | 2 +- tests/test_avro_eager.py | 31 ++++- 6 files changed, 92 insertions(+), 55 deletions(-) diff --git a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py index 1bb92a41d..55ae200bd 100644 --- a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py @@ -64,7 +64,7 @@ def __call__(self, resource, start, stop): super(AudioIOTensor, self).__init__( spec, resource, _Function(core_ops.wav_indexable_read, spec.shape, spec.dtype), - internal=internal) + partitions=None, internal=internal) #============================================================================= # Accessors diff --git a/tensorflow_io/core/python/ops/csv_io_tensor_ops.py b/tensorflow_io/core/python/ops/csv_io_tensor_ops.py index 5dac6be3e..239f87bd8 100644 --- a/tensorflow_io/core/python/ops/csv_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/csv_io_tensor_ops.py @@ -77,4 +77,4 @@ def __call__(self, resource, start, stop): return io_tensor_ops.BaseIOTensor( spec, self._resource, _Function(core_ops.csv_indexable_read, spec, column), - internal=True) + partitions=None, internal=True) diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py index e74bc037c..446751745 100644 --- a/tensorflow_io/core/python/ops/io_tensor.py +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -401,4 +401,5 @@ def from_avro(cls, """ with tf.name_scope(kwargs.get("name", "IOFromAvro")): - return avro_io_tensor_ops.AvroIOTensor(filename, schema, internal=True) + return avro_io_tensor_ops.AvroIOTensor( + filename, schema, internal=True, **kwargs) diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py index da58d04a3..17ff0bde1 100644 --- a/tensorflow_io/core/python/ops/io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/io_tensor_ops.py @@ -25,20 +25,52 @@ class _IOTensorMeta(property): """_IOTensorMeta is a decorator that is viewable to __repr__""" pass +class _IOTensorPartitionedFunction(object): + """PartitionedFunction will translate call to cached Function call""" + def __init__(self, func, partitions): + self._func = func + self._partitions 
= partitions + partitions_indices = tf.cumsum(partitions).numpy().tolist() + self._partitions_start = list([0] + partitions_indices[:-1]) + self._partitions_stop = partitions_indices + self._tensors = [None for _ in partitions] + + def __call__(self, resource, start, stop): + indices_start = tf.math.maximum(self._partitions_start, start) + indices_stop = tf.math.minimum(self._partitions_stop, stop) + indices_hit = tf.math.less(indices_start, indices_stop) + indices = tf.squeeze(tf.compat.v2.where(indices_hit), [1]) + items = [] + # TODO: change to tf.while_loop + for index in indices: + if self._tensors[index] is None: + self._tensors[index] = self._func( + resource, + self._partitions_start[index], + self._partitions_stop[index]) + slice_start = indices_start[index] - self._partitions_start[index] + slice_size = indices_stop[index] - indices_start[index] + item = tf.slice(self._tensors[index], [slice_start], [slice_size]) + items.append(item) + return tf.concat(items, axis=0) + class _IOTensorDataset(tf.compat.v2.data.Dataset): """_IOTensorDataset""" - def __init__(self, spec, resource, function): - start = 0 - stop = tf.nest.flatten(spec)[0].shape[0] - capacity = 4096 - entry_start = list(range(start, stop, capacity)) - entry_stop = entry_start[1:] + [stop] - + def __init__(self, spec, resource, function, partitions): + if partitions is None: + start = 0 + stop = tf.nest.flatten(spec)[0].shape[0] + capacity = 4096 + entry_start = list(range(start, stop, capacity)) + entry_stop = entry_start[1:] + [stop] + else: + partitions = tf.cast(partitions, tf.int64) + entry_stop = tf.cumsum(partitions) + entry_start = tf.concat([[0], entry_stop[:-1]], axis=0) dataset = tf.compat.v2.data.Dataset.from_tensor_slices(( tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64))) - dataset = dataset.map(lambda start, stop: function(resource, start=start, stop=stop)) dataset = dataset.unbatch() @@ -107,9 +139,15 @@ def __init__(self, spec, resource, function, + partitions, internal=False): + # function used for dataset should not be partitioned. + self._dataset_function = function + if partitions is not None: + function = _IOTensorPartitionedFunction(function, partitions) self._resource = resource self._function = function + self._partitions = partitions super(BaseIOTensor, self).__init__( spec, internal=internal) @@ -145,7 +183,7 @@ def to_dataset(self): A `tf.data.Dataset` with value obtained from this `IOTensor`. 
""" return _IOTensorDataset( - self.spec, self._resource, self._function) + self.spec, self._resource, self._dataset_function, self._partitions) #============================================================================= # Indexing & Slicing @@ -195,7 +233,7 @@ def __call__(self, resource, start, stop): return BaseIOTensor(spec, self._resource, _Function(self._function, spec, size), - internal=True) + partitions=None, internal=True) #============================================================================= # Tensor Type Conversions @@ -246,7 +284,7 @@ def __call__(self, resource, start, stop): super(TensorIOTensor, self).__init__( tf.TensorSpec(tensor.shape, tensor.dtype), - tensor, _Function(tensor), internal=internal) + tensor, _Function(tensor), None, internal=internal) #============================================================================= # Tensor Type Conversions @@ -312,40 +350,9 @@ def __call__(self, resource, start, stop): resource, start=start, stop=stop, component=self._component, shape=self._shape, dtype=self._dtype) - class _PartitionedFunction(object): - """PartitionedFunction will translate call to cached Function call""" - def __init__(self, func, partitions): - self._func = func - self._partitions = partitions - partitions_indices = tf.cumsum(partitions).numpy().tolist() - self._partitions_start = list([0] + partitions_indices[:-1]) - self._partitions_stop = partitions_indices - self._tensors = [None for _ in partitions] - - def __call__(self, resource, start, stop): - indices_start = tf.math.maximum(self._partitions_start, start) - indices_stop = tf.math.minimum(self._partitions_stop, stop) - indices_hit = tf.math.less(indices_start, indices_stop) - indices = tf.squeeze(tf.compat.v2.where(indices_hit), [1]) - items = [] - # TODO: change to tf.while_loop - for index in indices: - if self._tensors[index] is None: - self._tensors[index] = self._func( - resource, - self._partitions_start[index], - self._partitions_stop[index]) - slice_start = indices_start[index] - self._partitions_start[index] - slice_size = indices_stop[index] - indices_start[index] - item = tf.slice(self._tensors[index], [slice_start], [slice_size]) - items.append(item) - return tf.concat(items, axis=0) - function = _Function(self._function, spec, column) - if self._partitions is not None: - function = _PartitionedFunction(function, self._partitions) return BaseIOTensor( - spec, self._resource, function, + spec, self._resource, function, self._partitions, internal=True) #============================================================================= @@ -387,7 +394,7 @@ def __call__(self, resource, start, stop): return _IOTensorDataset( self.spec, self._resource, - _Function(self._function, self.spec, self.columns)) + _Function(self._function, self.spec, self.columns), self._partitions) class _CollectionIOTensor(_IOTensor): @@ -444,7 +451,7 @@ def __call__(self, resource, start, stop): return BaseIOTensor( spec, self._resource, _Function(self._function, spec, key), - internal=True) + partitions=None, internal=True) class _SeriesIOTensor(_IOTensor): """_SeriesIOTensor""" @@ -480,13 +487,19 @@ def __call__(self, resource, start, stop): def index(self): """The index column of the series""" return BaseIOTensor( - self.spec[0], self._resource, self._index_function, internal=True) + self.spec[0], + self._resource, + self._index_function, + partitions=None, internal=True) @property def value(self): """The value column of the series""" return BaseIOTensor( - self.spec[1], self._resource, 
self._value_function, internal=True) + self.spec[1], + self._resource, + self._value_function, + partitions=None, internal=True) class _KeyValueIOTensorDataset(tf.compat.v2.data.Dataset): """_KeyValueIOTensorDataset""" diff --git a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py index 90ff6103f..b7fc2d2de 100644 --- a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py @@ -64,4 +64,4 @@ def __call__(self, resource, start, stop): super(KafkaIOTensor, self).__init__( spec, resource, _Function(core_ops.kafka_indexable_read, spec), - internal=internal) + partitions=None, internal=internal) diff --git a/tests/test_avro_eager.py b/tests/test_avro_eager.py index fd0997250..a55e1c2f8 100644 --- a/tests/test_avro_eager.py +++ b/tests/test_avro_eager.py @@ -25,7 +25,6 @@ if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() import tensorflow_io as tfio # pylint: disable=wrong-import-position -from tensorflow_io.core.python.ops import avro_io_tensor_ops # pylint: disable=wrong-import-position def test_avro(): """test_avro""" @@ -77,8 +76,8 @@ def test_avro_partition(): 1, 2, 3, 11, 12, 13, 50, 51, 100]: - avro = avro_io_tensor_ops.AvroIOTensor( - filename, schema, capacity=capacity, internal=True) + avro = tfio.IOTensor.from_avro( + filename, schema, capacity=capacity) assert np.all( avro("im").to_tensor().numpy() == [100.0 + i for i in range(100)]) assert np.all( @@ -87,7 +86,7 @@ def test_avro_partition(): 1, 2, 3, 10, 11, 12, 13, 50, 51, 52, 53]: - indices = range(0, 100, step) + indices = list(range(0, 100, step)) for start, stop in zip(indices, indices[1:] + [100]): im_expected = [100.0 + i for i in range(start, stop)] im_items = avro("im")[start:stop] @@ -97,5 +96,29 @@ def test_avro_partition(): re_items = avro("re")[start:stop] assert np.all(re_items.numpy() == re_expected) +def test_avro_dataset_partition(): + """test_avro_dataset_partition""" + # The test.bin was created from avro/lang/c++/examples/datafile.cc. + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_avro", "test.bin") + filename = "file://" + filename + + schema_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_avro", "cpx.json") + with open(schema_filename, 'r') as f: + schema = f.read() + for capacity in [1, 2, 3, 11, 12, 13, 50, 51, 100]: + avro = tfio.IOTensor.from_avro( + filename, schema, capacity=capacity) + dataset = avro.to_dataset() + i = 0 + for v in dataset: + re, im = v + assert im.numpy() == 100.0 + i + assert re.numpy() == 100.0 * i + i += 1 + if __name__ == "__main__": test.main()
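

Appended notes (not part of the patch series): a few sketches that may help when reviewing. First, end-to-end usage of the new API, following tests/test_avro_eager.py; the fixture paths are assumptions, adjust them to your checkout:

  import os
  import tensorflow_io as tfio

  # Fixtures as used by the tests above (hypothetical local paths).
  filename = "file://" + os.path.abspath("tests/test_avro/test.bin")
  with open("tests/test_avro/cpx.json") as f:
    schema = f.read()

  avro = tfio.IOTensor.from_avro(filename, schema)
  print(avro("im").dtype, avro("im").shape)   # tf.float64, [100]
  window = avro("im")[10:20]                  # slicing decodes only the overlapping blocks
  for re, im in avro.to_dataset().take(3):    # stream all columns through tf.data
    print(re.numpy(), im.numpy())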
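The block-based pseudo-indexing that PATCH 1/5 describes, restated as a plain-Python sketch. scan_blocks and decode_block are hypothetical stand-ins for the C++ Avro reader calls (the previousSync/decodeLong header scan and reader_->read); the control flow mirrors AvroIndexable::Init and AvroIndexable::Read:

  def build_index(scan_blocks):
    """One cheap pass over the block headers, as in AvroIndexable::Init."""
    positions = scan_blocks()          # [(item_count, byte_offset), ...] per Avro block
    starts, total = [], 0
    for count, _ in positions:
      starts.append(total)             # first global row index of this block
      total += count
    return positions, starts, total

  def read_range(positions, starts, start, stop, decode_block):
    """Small-range iteration, as in AvroIndexable::Read."""
    out = []
    for (count, offset), first in zip(positions, starts):
      if first >= stop or first + count <= start:
        continue                       # block does not overlap [start, stop)
      records = decode_block(offset)   # decode the whole block...
      out.extend(records[max(start - first, 0):stop - first])  # ...keep the overlap
    return out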
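The partition-caching read path (PATCH 2/5's _PartitionedFunction, generalized into _IOTensorPartitionedFunction in PATCH 5/5) boils down to clip/overlap/cache/concat. A minimal NumPy sketch of the same arithmetic, assuming an eager read_fn(start, stop) that returns an array:

  import numpy as np

  class PartitionedRead(object):
    """Serve arbitrary [start, stop) slices from cached whole-partition reads."""
    def __init__(self, read_fn, partitions):
      stops = np.cumsum(partitions)
      self._starts = np.concatenate([[0], stops[:-1]])
      self._stops = stops
      self._read_fn = read_fn                # read_fn(start, stop) -> np.ndarray
      self._cache = [None] * len(partitions)

    def __call__(self, start, stop):
      lo = np.maximum(self._starts, start)   # clip the request to each partition
      hi = np.minimum(self._stops, stop)
      items = []
      for i in np.nonzero(lo < hi)[0]:       # partitions that overlap the request
        if self._cache[i] is None:           # first touch: read the whole partition
          self._cache[i] = self._read_fn(self._starts[i], self._stops[i])
        items.append(self._cache[i][lo[i] - self._starts[i]:hi[i] - self._starts[i]])
      return np.concatenate(items)

  # PartitionedRead(lambda a, b: np.arange(a, b), [3, 3, 4])(2, 8)
  # -> array([2, 3, 4, 5, 6, 7]); the middle partition is decoded exactly once.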
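Finally, the capacity re-chunking in AvroIOTensor.__init__ (PATCH 2/5), unrolled from the list comprehension for readability. Unlike the shipped one-liner, this sketch skips the zero-sized tail produced when a block size divides evenly; such empty partitions are harmless either way, since they can never satisfy start < stop at read time:

  def repartition(block_sizes, capacity):
    out = []
    for e in block_sizes:
      out.extend([capacity] * (e // capacity))   # full chunks of `capacity` rows
      if e % capacity:
        out.append(e % capacity)                 # remainder chunk, if any
    return out

  # repartition([40, 60], 25) -> [25, 15, 25, 25, 10]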