diff --git a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc index cbe0c9a4e..c9811edc9 100644 --- a/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc +++ b/tensorflow_io/arrow/kernels/arrow_dataset_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "arrow/util/io-util.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "tensorflow_io/arrow/kernels/arrow_kernels.h" #include "tensorflow_io/arrow/kernels/arrow_stream_client.h" #include "tensorflow_io/arrow/kernels/arrow_util.h" diff --git a/tensorflow_io/arrow/kernels/arrow_kernels.h b/tensorflow_io/arrow/kernels/arrow_kernels.h index 39abef014..6792b092a 100644 --- a/tensorflow_io/arrow/kernels/arrow_kernels.h +++ b/tensorflow_io/arrow/kernels/arrow_kernels.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_IO_ARROW_KERNELS_H_ #define TENSORFLOW_IO_ARROW_KERNELS_H_ -#include "kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "arrow/io/api.h" #include "arrow/buffer.h" #include "arrow/type.h" diff --git a/tensorflow_io/audio/kernels/audio_kernels.cc b/tensorflow_io/audio/kernels/audio_kernels.cc index 8d3c1e4d0..5d33b2387 100644 --- a/tensorflow_io/audio/kernels/audio_kernels.cc +++ b/tensorflow_io/audio/kernels/audio_kernels.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow_io/core/kernels/io_interface.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" namespace tensorflow { namespace data { diff --git a/tensorflow_io/audio/ops/audio_ops.cc b/tensorflow_io/audio/ops/audio_ops.cc index 420f5882a..f194faa03 100644 --- a/tensorflow_io/audio/ops/audio_ops.cc +++ b/tensorflow_io/audio/ops/audio_ops.cc @@ -31,7 +31,6 @@ REGISTER_OP("WAVIndexableInit") REGISTER_OP("WAVIndexableSpec") .Input("input: resource") - .Input("component: int64") .Output("shape: int64") .Output("dtype: int64") .Output("rate: int32") @@ -46,7 +45,6 @@ REGISTER_OP("WAVIndexableRead") .Input("input: resource") .Input("start: int64") .Input("stop: int64") - .Input("component: int64") .Output("value: dtype") .Attr("shape: shape") .Attr("dtype: type") diff --git a/tensorflow_io/avro/kernels/avro_kernels.cc b/tensorflow_io/avro/kernels/avro_kernels.cc index 5340cf91d..20c30359d 100644 --- a/tensorflow_io/avro/kernels/avro_kernels.cc +++ b/tensorflow_io/avro/kernels/avro_kernels.cc @@ -14,7 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "api/DataFile.hh" #include "api/Compiler.hh" #include "api/Generic.hh" @@ -287,6 +288,235 @@ REGISTER_KERNEL_BUILDER(Name("ReadAvro").Device(DEVICE_CPU), ReadAvroOp); + } // namespace + +class AvroIndexable : public IOIndexableInterface { + public: + AvroIndexable(Env* env) + : env_(env) {} + + ~AvroIndexable() {} + Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 filename is not supported"); + } + const string& filename = input[0]; + file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); + TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); + + string schema; + for (size_t i = 0; i < metadata.size(); i++) { + if (metadata[i].find("schema: ") == 0) { + schema = metadata[i].substr(8); + } + } + + string error; + std::istringstream ss(schema); + if (!(avro::compileJsonSchema(ss, reader_schema_, error))) { + return errors::Internal("Avro schema error: ", error); + } + + for (size_t i = 0; i < reader_schema_.root()->names(); i++) { + columns_.push_back(reader_schema_.root()->nameAt(i)); + columns_index_[reader_schema_.root()->nameAt(i)] = i; + } + + avro::GenericDatum datum(reader_schema_.root()); + const avro::GenericRecord& record = datum.value<avro::GenericRecord>(); + for (size_t i = 0; i < reader_schema_.root()->names(); i++) { + const avro::GenericDatum& field = record.field(columns_[i]); + ::tensorflow::DataType dtype; + switch(field.type()) { + case avro::AVRO_BOOL: + dtype = DT_BOOL; + break; + case avro::AVRO_INT: + dtype = DT_INT32; + break; + case avro::AVRO_LONG: + dtype = DT_INT64; + break; + case avro::AVRO_FLOAT: + dtype = DT_FLOAT; + break; + case avro::AVRO_DOUBLE: + dtype = DT_DOUBLE; + break; + case avro::AVRO_STRING: + dtype = DT_STRING; + break; + case avro::AVRO_BYTES: + dtype = DT_STRING; + break; + case avro::AVRO_FIXED: + dtype = DT_STRING; + break; + case avro::AVRO_ENUM: + dtype = DT_STRING; + break; + default: + return errors::InvalidArgument("Avro type unsupported: ", field.type()); + } + dtypes_.emplace_back(dtype); + } + + // Find out the total number of rows by walking the sync-delimited blocks; each block header starts with its item count. + reader_stream_.reset(new AvroInputStream(file_.get())); + reader_.reset(new avro::DataFileReader<avro::GenericDatum>(std::move(reader_stream_), reader_schema_)); + + avro::DecoderPtr decoder = avro::binaryDecoder(); + + int64 total = 0; + + reader_->sync(0); + int64 offset = reader_->previousSync(); + while (offset < file_size_) { + StringPiece result; + string buffer(16, 0x00); + TF_RETURN_IF_ERROR(file_->Read(offset, buffer.size(), &result, &buffer[0])); + std::unique_ptr<avro::InputStream> in = avro::memoryInputStream((const uint8_t*)result.data(), result.size()); + decoder->init(*in); + long items = decoder->decodeLong(); + + total += static_cast<int64>(items); + positions_.emplace_back(std::pair<int64, int64>(static_cast<int64>(items), offset)); + + reader_->sync(offset); + offset = reader_->previousSync(); + } + + for (size_t i = 0; i < columns_.size(); i++) { + shapes_.emplace_back(TensorShape({total})); + } + return Status::OK(); + } + + Status Partitions(std::vector<int64>* partitions) override { + partitions->clear(); + // positions_ are pairs of <items in block, block offset> + for (size_t i = 0; i < positions_.size(); i++) { +
partitions->emplace_back(positions_[i].first); + } + return Status::OK(); + } + + Status Components(Tensor* components) override { + *components = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())})); + for (size_t i = 0; i < columns_.size(); i++) { + components->flat<string>()(i) = columns_[i]; + } + return Status::OK(); + } + Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) override { + if (columns_index_.find(component.scalar<string>()()) == columns_index_.end()) { + return errors::InvalidArgument("component ", component.scalar<string>()(), " is invalid"); + } + int64 column_index = columns_index_[component.scalar<string>()()]; + *shape = shapes_[column_index]; + *dtype = dtypes_[column_index]; + return Status::OK(); + } + + Status Read(const int64 start, const int64 stop, const Tensor& component, Tensor* value, Tensor* label) override { + const string& column = component.scalar<string>()(); + avro::GenericDatum datum(reader_schema_); + + // Find the start sync point. Note the block size is accumulated before i is + // advanced, so positions_[i] is always in range when the increment runs. + int64 item_index_sync = 0; + for (size_t i = 0; i < positions_.size(); item_index_sync += positions_[i].first, i++) { + if (item_index_sync >= stop) { + continue; + } + if (item_index_sync + positions_[i].first <= start) { + continue; + } + // TODO: Avro is sync point partitioned and each block is very similar to + // a row group in Parquet. Ideally each block should be cached, with the hope + // that slicing and indexing will happen around the same block across multiple + // rows. Caching is not done yet. + + // Seek to sync + reader_->seek(positions_[i].second); + for (int64 item_index = item_index_sync; item_index < (item_index_sync + positions_[i].first) && item_index < stop; item_index++) { + // Read anyway + if (!reader_->read(datum)) { + return errors::Internal("unable to read record at: ", item_index); + } + // Assign only when in range + if (item_index >= start) { + const avro::GenericRecord& record = datum.value<avro::GenericRecord>(); + const avro::GenericDatum& field = record.field(column); + switch(field.type()) { + case avro::AVRO_BOOL: + value->flat<bool>()(item_index - start) = field.value<bool>(); + break; + case avro::AVRO_INT: + value->flat<int32>()(item_index - start) = field.value<int32_t>(); + break; + case avro::AVRO_LONG: + value->flat<int64>()(item_index - start) = field.value<int64_t>(); + break; + case avro::AVRO_FLOAT: + value->flat<float>()(item_index - start) = field.value<float>(); + break; + case avro::AVRO_DOUBLE: + value->flat<double>()(item_index - start) = field.value<double>(); + break; + case avro::AVRO_STRING: + value->flat<string>()(item_index - start) = field.value<string>(); + break; + case avro::AVRO_BYTES: { + const std::vector<uint8_t>& field_value = field.value<std::vector<uint8_t>>(); + value->flat<string>()(item_index - start) = string((char *)&field_value[0], field_value.size()); + } + break; + case avro::AVRO_FIXED: { + const std::vector<uint8_t>& field_value = field.value<avro::GenericFixed>().value(); + value->flat<string>()(item_index - start) = string((char *)&field_value[0], field_value.size()); + } + break; + case avro::AVRO_ENUM: + value->flat<string>()(item_index - start) = field.value<avro::GenericEnum>().symbol(); + break; + default: + return errors::InvalidArgument("unsupported data type: ", field.type()); + } + } + } + } + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("AvroIndexable"); + } + private: + mutable mutex mu_; + Env* env_ GUARDED_BY(mu_); + std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_); + uint64 file_size_ GUARDED_BY(mu_); + avro::ValidSchema reader_schema_; + std::unique_ptr<AvroInputStream> reader_stream_; + std::unique_ptr<avro::DataFileReader<avro::GenericDatum>> reader_; + std::vector<std::pair<int64, int64>> positions_; // pair of <items in block, block offset> + + std::vector<DataType> dtypes_; +
std::vector<TensorShape> shapes_; + std::vector<string> columns_; + std::unordered_map<string, int64> columns_index_; +}; + +REGISTER_KERNEL_BUILDER(Name("AvroIndexableInit").Device(DEVICE_CPU), + IOInterfaceInitOp<AvroIndexable>); +REGISTER_KERNEL_BUILDER(Name("AvroIndexableSpec").Device(DEVICE_CPU), + IOInterfaceSpecOp<AvroIndexable>); +REGISTER_KERNEL_BUILDER(Name("AvroIndexablePartitions").Device(DEVICE_CPU), + IOIndexablePartitionsOp<AvroIndexable>); +REGISTER_KERNEL_BUILDER(Name("AvroIndexableRead").Device(DEVICE_CPU), + IOIndexableReadOp<AvroIndexable>); + } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/avro/ops/avro_ops.cc b/tensorflow_io/avro/ops/avro_ops.cc index 1c090dd83..4d5f0f22f 100644 --- a/tensorflow_io/avro/ops/avro_ops.cc +++ b/tensorflow_io/avro/ops/avro_ops.cc @@ -45,4 +45,54 @@ REGISTER_OP("ReadAvro") return Status::OK(); }); +REGISTER_OP("AvroIndexableInit") + .Input("input: string") + .Input("metadata: string") + .Output("resource: resource") + .Output("component: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + +REGISTER_OP("AvroIndexableSpec") + .Input("input: resource") + .Input("component: string") + .Output("shape: int64") + .Output("dtype: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + c->set_output(1, c->MakeShape({})); + return Status::OK(); + }); + +REGISTER_OP("AvroIndexableRead") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Input("component: string") + .Output("value: dtype") + .Attr("filter: list(string) = []") + .Attr("shape: shape") + .Attr("dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); + return Status::OK(); + }); + +REGISTER_OP("AvroIndexablePartitions") + .Input("input: resource") + .Output("partitions: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + } // namespace tensorflow diff --git a/tensorflow_io/avro/python/ops/avro_ops.py b/tensorflow_io/avro/python/ops/avro_ops.py index 9a5278210..711d66ecb 100644 --- a/tensorflow_io/avro/python/ops/avro_ops.py +++ b/tensorflow_io/avro/python/ops/avro_ops.py @@ -17,10 +17,18 @@ from __future__ import division from __future__ import print_function +import warnings + import tensorflow as tf from tensorflow_io.core.python.ops import core_ops from tensorflow_io.core.python.ops import data_ops +warnings.warn( + "The tensorflow_io.avro.AvroDataset is " + "deprecated. 
Please look for tfio.IOTensor.from_avro " + "for reading Avro files into TensorFlow.", + DeprecationWarning) + def list_avro_columns(filename, schema, **kwargs): """list_avro_columns""" if not tf.executing_eagerly(): diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index b3a785306..2cc14f179 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -30,7 +30,7 @@ cc_library( srcs = [ "kernels/dataset_ops.h", "kernels/io_interface.h", - "kernels/stream.h", + "kernels/io_stream.h", ], copts = tf_io_copts(), includes = [ diff --git a/tensorflow_io/core/kernels/archive_kernels.cc b/tensorflow_io/core/kernels/archive_kernels.cc index 2149687b9..5b7358fd7 100644 --- a/tensorflow_io/core/kernels/archive_kernels.cc +++ b/tensorflow_io/core/kernels/archive_kernels.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_inputstream.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" namespace tensorflow { namespace data { diff --git a/tensorflow_io/core/kernels/io_interface.h b/tensorflow_io/core/kernels/io_interface.h index 83a112328..fd6f1c697 100644 --- a/tensorflow_io/core/kernels/io_interface.h +++ b/tensorflow_io/core/kernels/io_interface.h @@ -25,9 +25,13 @@ class IOInterface : public ResourceBase { virtual Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) = 0; virtual Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype, bool label) = 0; + virtual Status Partitions(std::vector<int64>* partitions) { + // By default Partitions is not implemented: Unimplemented + return errors::Unimplemented("Partitions"); + } virtual Status Components(Tensor* components) { // By default there is only one component: Unimplemented - return errors::Unimplemented("Component"); + return errors::Unimplemented("Components"); } virtual Status Extra(const Tensor& component, std::vector<Tensor>* extra) { // This is the chance to provide additional information, which should be appended to extra. 
@@ -245,8 +249,15 @@ class IOInterfaceSpecOp : public OpKernel { OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource)); core::ScopedUnref unref(resource); + Status status; + + Tensor component_empty(DT_INT64, TensorShape({})); + component_empty.scalar<int64>()() = 0; const Tensor* component; - OP_REQUIRES_OK(context, context->input("component", &component)); + status = context->input("component", &component); + if (!status.ok()) { + component = &component_empty; + } PartialTensorShape shape; DataType dtype; @@ -262,7 +273,7 @@ class IOInterfaceSpecOp : public OpKernel { context->set_output(1, dtype_tensor); std::vector<Tensor> extra; - Status status = resource->Extra(*component, &extra); + status = resource->Extra(*component, &extra); if (!errors::IsUnimplemented(status)) { OP_REQUIRES_OK(context, status); for (size_t i = 0; i < extra.size(); i++) { @@ -304,8 +315,15 @@ class IOIterableNextOp : public OpKernel { OP_REQUIRES_OK(context, context->input("capacity", &capacity_tensor)); const int64 capacity = capacity_tensor->scalar<int64>()(); + Status status; + + Tensor component_empty(DT_INT64, TensorShape({})); + component_empty.scalar<int64>()() = 0; const Tensor* component; - OP_REQUIRES_OK(context, context->input("component", &component)); + status = context->input("component", &component); + if (!status.ok()) { + component = &component_empty; + } OP_REQUIRES(context, (capacity > 0), errors::InvalidArgument("capacity <= 0 is not supported: ", capacity)); @@ -399,8 +417,15 @@ class IOIndexableReadOp : public OpKernel { OP_REQUIRES_OK(context, context->input("stop", &stop_tensor)); int64 stop = stop_tensor->scalar<int64>()(); + Status status; + + Tensor component_empty(DT_INT64, TensorShape({})); + component_empty.scalar<int64>()() = 0; const Tensor* component; - OP_REQUIRES_OK(context, context->input("component", &component)); + status = context->input("component", &component); + if (!status.ok()) { + component = &component_empty; + } int64 output_index = 0; Tensor* value_tensor = nullptr; @@ -451,5 +476,28 @@ class IOMappingReadOp : public OpKernel { context->set_output(0, value); } }; +template <typename Type> +class IOIndexablePartitionsOp : public OpKernel { + public: + explicit IOIndexablePartitionsOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + } + + void Compute(OpKernelContext* context) override { + Type* resource; + OP_REQUIRES_OK(context, GetResourceFromContext(context, "input", &resource)); + core::ScopedUnref unref(resource); + + std::vector<int64> partitions; + OP_REQUIRES_OK(context, resource->Partitions(&partitions)); + + Tensor partitions_tensor(DT_INT64, TensorShape({static_cast<int64>(partitions.size())})); + for (size_t i = 0; i < partitions.size(); i++) { + partitions_tensor.flat<int64>()(i) = partitions[i]; + } + + context->set_output(0, partitions_tensor); + } +}; } // namespace data } // namespace tensorflow diff --git a/tensorflow_io/core/kernels/stream.h b/tensorflow_io/core/kernels/io_stream.h similarity index 100% rename from tensorflow_io/core/kernels/stream.h rename to tensorflow_io/core/kernels/io_stream.h diff --git a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py index bc8a803aa..55ae200bd 100644 --- a/tensorflow_io/core/python/ops/audio_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/audio_io_tensor_ops.py @@ -44,29 +44,27 @@ def __init__(self, filename, container=scope, shared_name="%s/%s" % (filename, uuid.uuid4().hex)) - shape, dtype, rate = core_ops.wav_indexable_spec(resource, component=0) + shape, dtype, rate = 
core_ops.wav_indexable_spec(resource) shape = tf.TensorShape(shape.numpy()) dtype = tf.as_dtype(dtype.numpy()) spec = tf.TensorSpec(shape, dtype) class _Function(object): - def __init__(self, func, shape, dtype, component=0): + def __init__(self, func, shape, dtype): self._func = func self._shape = tf.TensorShape([None]).concatenate(shape[1:]) self._dtype = dtype - self._component = component def __call__(self, resource, start, stop): return self._func( resource, start=start, stop=stop, - component=self._component, shape=self._shape, dtype=self._dtype) self._rate = rate.numpy() super(AudioIOTensor, self).__init__( spec, resource, _Function(core_ops.wav_indexable_read, spec.shape, spec.dtype), - internal=internal) + partitions=None, internal=internal) #============================================================================= # Accessors diff --git a/tensorflow_io/core/python/ops/avro_io_tensor_ops.py b/tensorflow_io/core/python/ops/avro_io_tensor_ops.py new file mode 100644 index 000000000..0c5eaa418 --- /dev/null +++ b/tensorflow_io/core/python/ops/avro_io_tensor_ops.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AvroIOTensor""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import io_tensor_ops +from tensorflow_io.core.python.ops import core_ops + +class AvroIOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access + """AvroIOTensor""" + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + filename, + schema, + capacity=None, + internal=False): + with tf.name_scope("AvroIOTensor") as scope: + metadata = ["schema: %s" % schema] + resource, columns = core_ops.avro_indexable_init( + filename, metadata=metadata, + container=scope, + shared_name="%s/%s" % (filename, uuid.uuid4().hex)) + partitions = None + if capacity is not None: + partitions = core_ops.avro_indexable_partitions(resource) + partitions = partitions.numpy().tolist() + if capacity > 0: + partitions = [ + v for e in partitions for v in list( + [capacity] * (e // capacity) + [e % capacity])] + columns = [column.decode() for column in columns.numpy().tolist()] + spec = [] + for column in columns: + shape, dtype = core_ops.avro_indexable_spec(resource, column) + shape = tf.TensorShape(shape) + dtype = tf.as_dtype(dtype.numpy()) + spec.append(tf.TensorSpec(shape, dtype, column)) + spec = tuple(spec) + super(AvroIOTensor, self).__init__( + spec, columns, + resource, core_ops.avro_indexable_read, + partitions=partitions, internal=internal) diff --git a/tensorflow_io/core/python/ops/csv_io_tensor_ops.py b/tensorflow_io/core/python/ops/csv_io_tensor_ops.py 
index 8dad814b3..239f87bd8 100644 --- a/tensorflow_io/core/python/ops/csv_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/csv_io_tensor_ops.py @@ -48,6 +48,7 @@ def __init__(self, super(CSVIOTensor, self).__init__( spec, columns, resource, core_ops.csv_indexable_read, + partitions=None, internal=internal) #============================================================================= @@ -76,4 +77,4 @@ def __call__(self, resource, start, stop): return io_tensor_ops.BaseIOTensor( spec, self._resource, _Function(core_ops.csv_indexable_read, spec, column), - internal=True) + partitions=None, internal=True) diff --git a/tensorflow_io/core/python/ops/feather_io_tensor_ops.py b/tensorflow_io/core/python/ops/feather_io_tensor_ops.py index 18f281598..6bf72bbbb 100644 --- a/tensorflow_io/core/python/ops/feather_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/feather_io_tensor_ops.py @@ -48,4 +48,5 @@ def __init__(self, super(FeatherIOTensor, self).__init__( spec, columns, resource, core_ops.feather_indexable_read, + partitions=None, internal=internal) diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py index e86cf60c4..446751745 100644 --- a/tensorflow_io/core/python/ops/io_tensor.py +++ b/tensorflow_io/core/python/ops/io_tensor.py @@ -27,6 +27,7 @@ from tensorflow_io.core.python.ops import prometheus_io_tensor_ops from tensorflow_io.core.python.ops import feather_io_tensor_ops from tensorflow_io.core.python.ops import csv_io_tensor_ops +from tensorflow_io.core.python.ops import avro_io_tensor_ops class IOTensor(io_tensor_ops._IOTensor): # pylint: disable=protected-access """IOTensor @@ -382,3 +383,23 @@ def from_csv(cls, """ with tf.name_scope(kwargs.get("name", "IOFromCSV")): return csv_io_tensor_ops.CSVIOTensor(filename, internal=True) + + @classmethod + def from_avro(cls, + filename, + schema, + **kwargs): + """Creates an `IOTensor` from an Avro file. + + Args: + filename: A string, the filename of an Avro file. + schema: A string, the schema of an Avro file. + name: A name prefix for the IOTensor (optional). + + Returns: + An `IOTensor`. 
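+ + Example (a minimal sketch; `"cpx.avro"` and `schema` are placeholders for a real Avro file and its JSON schema string): + + >>> import tensorflow_io as tfio + >>> avro = tfio.IOTensor.from_avro("cpx.avro", schema) + >>> avro("re").to_tensor()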
+ + """ + with tf.name_scope(kwargs.get("name", "IOFromAvro")): + return avro_io_tensor_ops.AvroIOTensor( + filename, schema, internal=True, **kwargs) diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py index d084e9003..17ff0bde1 100644 --- a/tensorflow_io/core/python/ops/io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/io_tensor_ops.py @@ -25,20 +25,52 @@ class _IOTensorMeta(property): """_IOTensorMeta is a decorator that is viewable to __repr__""" pass +class _IOTensorPartitionedFunction(object): + """PartitionedFunction will translate call to cached Function call""" + def __init__(self, func, partitions): + self._func = func + self._partitions = partitions + partitions_indices = tf.cumsum(partitions).numpy().tolist() + self._partitions_start = list([0] + partitions_indices[:-1]) + self._partitions_stop = partitions_indices + self._tensors = [None for _ in partitions] + + def __call__(self, resource, start, stop): + indices_start = tf.math.maximum(self._partitions_start, start) + indices_stop = tf.math.minimum(self._partitions_stop, stop) + indices_hit = tf.math.less(indices_start, indices_stop) + indices = tf.squeeze(tf.compat.v2.where(indices_hit), [1]) + items = [] + # TODO: change to tf.while_loop + for index in indices: + if self._tensors[index] is None: + self._tensors[index] = self._func( + resource, + self._partitions_start[index], + self._partitions_stop[index]) + slice_start = indices_start[index] - self._partitions_start[index] + slice_size = indices_stop[index] - indices_start[index] + item = tf.slice(self._tensors[index], [slice_start], [slice_size]) + items.append(item) + return tf.concat(items, axis=0) + class _IOTensorDataset(tf.compat.v2.data.Dataset): """_IOTensorDataset""" - def __init__(self, spec, resource, function): - start = 0 - stop = tf.nest.flatten(spec)[0].shape[0] - capacity = 4096 - entry_start = list(range(start, stop, capacity)) - entry_stop = entry_start[1:] + [stop] - + def __init__(self, spec, resource, function, partitions): + if partitions is None: + start = 0 + stop = tf.nest.flatten(spec)[0].shape[0] + capacity = 4096 + entry_start = list(range(start, stop, capacity)) + entry_stop = entry_start[1:] + [stop] + else: + partitions = tf.cast(partitions, tf.int64) + entry_stop = tf.cumsum(partitions) + entry_start = tf.concat([[0], entry_stop[:-1]], axis=0) dataset = tf.compat.v2.data.Dataset.from_tensor_slices(( tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64))) - dataset = dataset.map(lambda start, stop: function(resource, start=start, stop=stop)) dataset = dataset.unbatch() @@ -107,9 +139,15 @@ def __init__(self, spec, resource, function, + partitions, internal=False): + # function used for dataset should not be partitioned. + self._dataset_function = function + if partitions is not None: + function = _IOTensorPartitionedFunction(function, partitions) self._resource = resource self._function = function + self._partitions = partitions super(BaseIOTensor, self).__init__( spec, internal=internal) @@ -145,7 +183,7 @@ def to_dataset(self): A `tf.data.Dataset` with value obtained from this `IOTensor`. 
""" return _IOTensorDataset( - self.spec, self._resource, self._function) + self.spec, self._resource, self._dataset_function, self._partitions) #============================================================================= # Indexing & Slicing @@ -195,7 +233,7 @@ def __call__(self, resource, start, stop): return BaseIOTensor(spec, self._resource, _Function(self._function, spec, size), - internal=True) + partitions=None, internal=True) #============================================================================= # Tensor Type Conversions @@ -246,7 +284,7 @@ def __call__(self, resource, start, stop): super(TensorIOTensor, self).__init__( tf.TensorSpec(tensor.shape, tensor.dtype), - tensor, _Function(tensor), internal=internal) + tensor, _Function(tensor), None, internal=internal) #============================================================================= # Tensor Type Conversions @@ -277,10 +315,12 @@ def __init__(self, columns, resource, function, + partitions, internal=False): self._columns = columns self._resource = resource self._function = function + self._partitions = partitions super(_TableIOTensor, self).__init__( spec, internal=internal) @@ -310,9 +350,9 @@ def __call__(self, resource, start, stop): resource, start=start, stop=stop, component=self._component, shape=self._shape, dtype=self._dtype) - + function = _Function(self._function, spec, column) return BaseIOTensor( - spec, self._resource, _Function(self._function, spec, column), + spec, self._resource, function, self._partitions, internal=True) #============================================================================= @@ -354,7 +394,7 @@ def __call__(self, resource, start, stop): return _IOTensorDataset( self.spec, self._resource, - _Function(self._function, self.spec, self.columns)) + _Function(self._function, self.spec, self.columns), self._partitions) class _CollectionIOTensor(_IOTensor): @@ -411,7 +451,7 @@ def __call__(self, resource, start, stop): return BaseIOTensor( spec, self._resource, _Function(self._function, spec, key), - internal=True) + partitions=None, internal=True) class _SeriesIOTensor(_IOTensor): """_SeriesIOTensor""" @@ -447,13 +487,19 @@ def __call__(self, resource, start, stop): def index(self): """The index column of the series""" return BaseIOTensor( - self.spec[0], self._resource, self._index_function, internal=True) + self.spec[0], + self._resource, + self._index_function, + partitions=None, internal=True) @property def value(self): """The value column of the series""" return BaseIOTensor( - self.spec[1], self._resource, self._value_function, internal=True) + self.spec[1], + self._resource, + self._value_function, + partitions=None, internal=True) class _KeyValueIOTensorDataset(tf.compat.v2.data.Dataset): """_KeyValueIOTensorDataset""" diff --git a/tensorflow_io/core/python/ops/json_io_tensor_ops.py b/tensorflow_io/core/python/ops/json_io_tensor_ops.py index 64ab43e4c..6bea4a21b 100644 --- a/tensorflow_io/core/python/ops/json_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/json_io_tensor_ops.py @@ -51,4 +51,5 @@ def __init__(self, super(JSONIOTensor, self).__init__( spec, columns, resource, core_ops.json_indexable_read, + partitions=None, internal=internal) diff --git a/tensorflow_io/core/python/ops/kafka_dataset_ops.py b/tensorflow_io/core/python/ops/kafka_dataset_ops.py index c352d97ba..e413615e0 100644 --- a/tensorflow_io/core/python/ops/kafka_dataset_ops.py +++ b/tensorflow_io/core/python/ops/kafka_dataset_ops.py @@ -62,7 +62,7 @@ def __init__(self, dataset = 
tf.compat.v2.data.Dataset.range(0, sys.maxsize, capacity) dataset = dataset.map( lambda i: core_ops.kafka_iterable_next( - resource, capacity, component=0, + resource, capacity, dtype=tf.string, shape=tf.TensorShape([None]))) dataset = dataset.apply( tf.data.experimental.take_while( diff --git a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py index 381c18817..b7fc2d2de 100644 --- a/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/kafka_io_tensor_ops.py @@ -58,11 +58,10 @@ def __init__(self, func, spec): def __call__(self, resource, start, stop): return self._func( resource, start=start, stop=stop, - component=0, shape=self._shape, dtype=self._dtype) self._iterable = iterable super(KafkaIOTensor, self).__init__( spec, resource, _Function(core_ops.kafka_indexable_read, spec), - internal=internal) + partitions=None, internal=internal) diff --git a/tensorflow_io/core/python/ops/lmdb_io_tensor_ops.py b/tensorflow_io/core/python/ops/lmdb_io_tensor_ops.py index 34482da2e..e5044ffbf 100644 --- a/tensorflow_io/core/python/ops/lmdb_io_tensor_ops.py +++ b/tensorflow_io/core/python/ops/lmdb_io_tensor_ops.py @@ -64,7 +64,7 @@ def __init__(self, func, shape, dtype): self._dtype = dtype def __call__(self, resource, capacity): return self._func( - resource, capacity, component=0, + resource, capacity, shape=self._shape, dtype=self._dtype) super(LMDBIOTensor, self).__init__( diff --git a/tensorflow_io/hdf5/kernels/hdf5_kernels.cc b/tensorflow_io/hdf5/kernels/hdf5_kernels.cc index 0673b2440..2e02f3654 100644 --- a/tensorflow_io/hdf5/kernels/hdf5_kernels.cc +++ b/tensorflow_io/hdf5/kernels/hdf5_kernels.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow_io/core/kernels/io_interface.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include <hdf5.h> #include <hdf5_hl.h> diff --git a/tensorflow_io/json/kernels/json_kernels.cc b/tensorflow_io/json/kernels/json_kernels.cc index 98bca657b..3ac3f00fe 100644 --- a/tensorflow_io/json/kernels/json_kernels.cc +++ b/tensorflow_io/json/kernels/json_kernels.cc @@ -17,12 +17,11 @@ limitations under the License. 
#include #include "kernels/dataset_ops.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/kernels/stream.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/platform/env.h" #include "include/json/json.h" #include "tensorflow_io/core/kernels/io_interface.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "arrow/memory_pool.h" #include "arrow/json/reader.h" #include "arrow/table.h" diff --git a/tensorflow_io/kafka/ops/kafka_ops.cc b/tensorflow_io/kafka/ops/kafka_ops.cc index 6de3e420b..28d8733e0 100644 --- a/tensorflow_io/kafka/ops/kafka_ops.cc +++ b/tensorflow_io/kafka/ops/kafka_ops.cc @@ -33,7 +33,6 @@ REGISTER_OP("KafkaIndexableInit") REGISTER_OP("KafkaIndexableSpec") .Input("input: resource") - .Input("component: int64") .Output("shape: int64") .Output("dtype: int64") .SetShapeFn([](shape_inference::InferenceContext* c) { @@ -46,7 +45,6 @@ REGISTER_OP("KafkaIndexableRead") .Input("input: resource") .Input("start: int64") .Input("stop: int64") - .Input("component: int64") .Output("value: dtype") .Attr("shape: shape") .Attr("dtype: type") @@ -73,7 +71,6 @@ REGISTER_OP("KafkaIterableInit") REGISTER_OP("KafkaIterableNext") .Input("input: resource") .Input("capacity: int64") - .Input("component: int64") .Output("value: dtype") .Attr("shape: shape") .Attr("dtype: type") diff --git a/tensorflow_io/lmdb/kernels/lmdb_kernels.cc b/tensorflow_io/lmdb/kernels/lmdb_kernels.cc index bd6f85a79..966ef7228 100644 --- a/tensorflow_io/lmdb/kernels/lmdb_kernels.cc +++ b/tensorflow_io/lmdb/kernels/lmdb_kernels.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/kernels/stream.h" #include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include #include #include "lmdb.h" diff --git a/tensorflow_io/lmdb/ops/lmdb_ops.cc b/tensorflow_io/lmdb/ops/lmdb_ops.cc index c08b36041..367384ec4 100644 --- a/tensorflow_io/lmdb/ops/lmdb_ops.cc +++ b/tensorflow_io/lmdb/ops/lmdb_ops.cc @@ -56,7 +56,6 @@ REGISTER_OP("LMDBIterableInit") REGISTER_OP("LMDBIterableNext") .Input("input: resource") .Input("capacity: int64") - .Input("component: int64") .Output("value: dtype") .Attr("shape: shape") .Attr("dtype: type") diff --git a/tensorflow_io/text/kernels/csv_kernels.cc b/tensorflow_io/text/kernels/csv_kernels.cc index 86ee70a57..c2318af1b 100644 --- a/tensorflow_io/text/kernels/csv_kernels.cc +++ b/tensorflow_io/text/kernels/csv_kernels.cc @@ -14,10 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow_io/core/kernels/stream.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow_io/core/kernels/io_interface.h" -#include "tensorflow_io/core/kernels/stream.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "arrow/memory_pool.h" #include "arrow/csv/reader.h" #include "arrow/table.h" diff --git a/tensorflow_io/text/kernels/text_kernels.cc b/tensorflow_io/text/kernels/text_kernels.cc index 0481237a9..8f99fc468 100644 --- a/tensorflow_io/text/kernels/text_kernels.cc +++ b/tensorflow_io/text/kernels/text_kernels.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow_io/core/kernels/io_stream.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" -#include "tensorflow_io/core/kernels/stream.h" namespace tensorflow { namespace data { diff --git a/tests/test_avro_eager.py b/tests/test_avro_eager.py index 172697239..a55e1c2f8 100644 --- a/tests/test_avro_eager.py +++ b/tests/test_avro_eager.py @@ -12,21 +12,22 @@ # License for the specific language governing permissions and limitations under # the License. # ============================================================================== -"""Tests for AvroDataset.""" +"""Tests for tfio.IOTensor.from_avro.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os +import numpy as np import tensorflow as tf if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")): tf.compat.v1.enable_eager_execution() -import tensorflow_io.avro as avro_io # pylint: disable=wrong-import-position +import tensorflow_io as tfio # pylint: disable=wrong-import-position def test_avro(): - """test_list_avro_columns.""" + """test_avro""" # The test.bin was created from avro/lang/c++/examples/datafile.cc. filename = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -39,30 +40,84 @@ def test_avro(): with open(schema_filename, 'r') as f: schema = f.read() - specs = avro_io.list_avro_columns(filename, schema) - assert specs["im"].dtype == tf.float64 - assert specs["re"].dtype == tf.float64 + avro = tfio.IOTensor.from_avro(filename, schema) + assert avro("im").dtype == tf.float64 + assert avro("im").shape == [100] + assert avro("re").dtype == tf.float64 + assert avro("re").shape == [100] - v0 = avro_io.read_avro(filename, schema, specs["im"]) - v1 = avro_io.read_avro(filename, schema, specs["re"]) - for i in range(100): - (im, re) = (i + 100, i * 100) - assert v0[i].numpy() == im - assert v1[i].numpy() == re + assert np.all( + avro("im").to_tensor().numpy() == [100.0 + i for i in range(100)]) + assert np.all( + avro("re").to_tensor().numpy() == [100.0 * i for i in range(100)]) - for capacity in [10, 20, 50, 100, 1000, 2000]: - dataset = tf.compat.v2.data.Dataset.zip( - ( - avro_io.AvroDataset(filename, schema, "im", capacity=capacity), - avro_io.AvroDataset(filename, schema, "re", capacity=capacity) - ) - ).apply(tf.data.experimental.unbatch()) + dataset = avro.to_dataset() + i = 0 + for v in dataset: + re, im = v + assert im.numpy() == 100.0 + i + assert re.numpy() == 100.0 * i + i += 1 + +def test_avro_partition(): + """test_avro_partition""" + # The test.bin was created from avro/lang/c++/examples/datafile.cc. 
+ filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_avro", "test.bin") + filename = "file://" + filename + + schema_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_avro", "cpx.json") + with open(schema_filename, 'r') as f: + schema = f.read() + for capacity in [ + 1, 2, 3, + 11, 12, 13, + 50, 51, 100]: + avro = tfio.IOTensor.from_avro( + filename, schema, capacity=capacity) + assert np.all( + avro("im").to_tensor().numpy() == [100.0 + i for i in range(100)]) + assert np.all( + avro("re").to_tensor().numpy() == [100.0 * i for i in range(100)]) + for step in [ + 1, 2, 3, + 10, 11, 12, 13, + 50, 51, 52, 53]: + indices = list(range(0, 100, step)) + for start, stop in zip(indices, indices[1:] + [100]): + im_expected = [100.0 + i for i in range(start, stop)] + im_items = avro("im")[start:stop] + assert np.all(im_items.numpy() == im_expected) + + re_expected = [100.0 * i for i in range(start, stop)] + re_items = avro("re")[start:stop] + assert np.all(re_items.numpy() == re_expected) + +def test_avro_dataset_partition(): + """test_avro_dataset_partition""" + # The test.bin was created from avro/lang/c++/examples/datafile.cc. + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_avro", "test.bin") + filename = "file://" + filename + + schema_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_avro", "cpx.json") + with open(schema_filename, 'r') as f: + schema = f.read() + for capacity in [1, 2, 3, 11, 12, 13, 50, 51, 100]: + avro = tfio.IOTensor.from_avro( + filename, schema, capacity=capacity) + dataset = avro.to_dataset() i = 0 - for vv in dataset: - v0, v1 = vv - (im, re) = (i + 100, i * 100) - assert v0.numpy() == im - assert v1.numpy() == re + for v in dataset: + re, im = v + assert im.numpy() == 100.0 + i + assert re.numpy() == 100.0 * i i += 1 if __name__ == "__main__":