From aada0a5b6d0b271406f4faddd40eec72cc3f294a Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Wed, 31 Jul 2019 04:25:38 +0000
Subject: [PATCH] Add read_avro and list_avro_columns to rework splittable Avro
 support

This PR is part of the effort to rework Dataset so that large files are read
into Tensors first, to speed up performance. See 382 and 366 for related
discussions.

Summary:
1) read_avro is able to read an avro file within the range
   [offset, offset+length] (splittable).
2) We use the primitive read_avro C++ op to read in big chunks, and then wire
   the output up with tf.data.Dataset.
3) read_avro could be used in other places as well.
4) AvroDataset automatically finds out the dtype in eager mode; in graph mode,
   the user has to specify the dtype in kwargs.

Signed-off-by: Yong Tang
---
 tensorflow_io/avro/BUILD                   |   2 +-
 tensorflow_io/avro/__init__.py             |   6 +
 tensorflow_io/avro/kernels/avro_input.cc   | 248 -----------------
 tensorflow_io/avro/kernels/avro_kernels.cc | 292 +++++++++++++++++++++
 tensorflow_io/avro/ops/avro_ops.cc         |  32 +--
 tensorflow_io/avro/python/ops/avro_ops.py  |  91 ++++---
 tensorflow_io/core/BUILD                   |   1 +
 tensorflow_io/core/kernels/stream.h        |  71 +++++
 tensorflow_io/core/python/ops/data_ops.py  |   5 +-
 tests/test_avro.py                         |  89 +++----
 tests/test_avro_eager.py                   |  69 +++++
 11 files changed, 541 insertions(+), 365 deletions(-)
 delete mode 100644 tensorflow_io/avro/kernels/avro_input.cc
 create mode 100644 tensorflow_io/avro/kernels/avro_kernels.cc
 create mode 100644 tensorflow_io/core/kernels/stream.h
 create mode 100644 tests/test_avro_eager.py

diff --git a/tensorflow_io/avro/BUILD b/tensorflow_io/avro/BUILD
index f9e288339..77e0d9952 100644
--- a/tensorflow_io/avro/BUILD
+++ b/tensorflow_io/avro/BUILD
@@ -10,7 +10,7 @@ load(
 cc_library(
     name = "avro_ops",
     srcs = [
-        "kernels/avro_input.cc",
+        "kernels/avro_kernels.cc",
         "ops/avro_ops.cc",
     ],
     copts = tf_io_copts(),
diff --git a/tensorflow_io/avro/__init__.py b/tensorflow_io/avro/__init__.py
index 11e176484..ecf68a385 100644
--- a/tensorflow_io/avro/__init__.py
+++ b/tensorflow_io/avro/__init__.py
@@ -15,6 +15,8 @@
 """Avro Dataset.
 
 @@AvroDataset
+@@list_avro_columns
+@@read_avro
 """
 
 from __future__ import absolute_import
@@ -22,11 +24,15 @@ from __future__ import print_function
 
 from tensorflow_io.avro.python.ops.avro_ops import AvroDataset
+from tensorflow_io.avro.python.ops.avro_ops import list_avro_columns
+from tensorflow_io.avro.python.ops.avro_ops import read_avro
 
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
     "AvroDataset",
+    "list_avro_columns",
+    "read_avro",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow_io/avro/kernels/avro_input.cc b/tensorflow_io/avro/kernels/avro_input.cc
deleted file mode 100644
index 8f081165d..000000000
--- a/tensorflow_io/avro/kernels/avro_input.cc
+++ /dev/null
@@ -1,248 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "kernels/dataset_ops.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
-#include "api/DataFile.hh"
-#include "api/Compiler.hh"
-#include "api/Generic.hh"
-#include "api/Stream.hh"
-#include <sstream>
-
-namespace tensorflow {
-namespace data {
-static const size_t kAvroDataInputStreamBufferSize = 8192;
-class AvroDataInputStream : public avro::InputStream {
-public:
-  AvroDataInputStream(io::InputStreamInterface* s)
-    : stream_(s) {}
-  virtual ~AvroDataInputStream() {}
-  bool next(const uint8_t** data, size_t* len) override {
-    if (*len == 0) {
-      *len = kAvroDataInputStreamBufferSize;
-    }
-    if (*len <= prefix_.size()) {
-      buffer_ = prefix_.substr(0, *len);
-      prefix_ = prefix_.substr(*len);
-    } else {
-      int64 bytes_to_read = *len - prefix_.size();
-      string chunk;
-      stream_->ReadNBytes(bytes_to_read, &chunk);
-      buffer_ = std::move(prefix_);
-      buffer_.append(chunk);
-      prefix_.clear();
-    }
-    *data = (const uint8_t*)buffer_.data();
-    *len = buffer_.size();
-    byte_count_ += *len;
-    return (*len != 0);
-  }
-  void backup(size_t len) override {
-    string chunk = buffer_.substr(buffer_.size() - len);
-    chunk.append(prefix_);
-    prefix_ = std::move(chunk);
-    byte_count_ -= len;
-  }
-  void skip(size_t len) override {
-    if (len <= prefix_.size()) {
-      prefix_ = prefix_.substr(len);
-    } else {
-      int64 bytes_to_read = len - prefix_.size();
-      stream_->SkipNBytes(bytes_to_read);
-      prefix_.clear();
-    }
-    byte_count_ += len;
-  }
-  size_t byteCount() const override {
-    return byte_count_;
-  }
-private:
-  io::InputStreamInterface* stream_;
-  size_t byte_count_ = 0;
-  string prefix_;
-  string buffer_;
-};
-
-class AvroInputStream {
-public:
-  explicit AvroInputStream(io::InputStreamInterface* s, const string& schema, const std::vector<string>& columns)
-    : stream_(s)
-    , schema_(schema)
-    , columns_(columns)
-    , reader_(nullptr) {
-  }
-
-  ~AvroInputStream() {
-    reader_.reset(nullptr);
-  }
-
-  Status Open() {
-    string error;
-    std::istringstream ss(schema_);
-    if (!avro::compileJsonSchema(ss, reader_schema_, error)) {
-      return errors::Unimplemented("Avro schema error: ", error);
-    }
-    std::unique_ptr<avro::InputStream> stream(static_cast<avro::InputStream*>(new AvroDataInputStream(stream_)));
-    reader_.reset(new avro::DataFileReader<avro::GenericDatum>(std::move(stream), reader_schema_));
-    return Status::OK();
-  }
-  const avro::ValidSchema& ReaderSchema() const {
-    return reader_schema_;
-  }
-  bool ReadDatum(avro::GenericDatum& datum) {
-    return reader_->read(datum);
-  }
-private:
-  io::InputStreamInterface* stream_;
-  string schema_;
-  std::vector<string> columns_;
-  std::unique_ptr<avro::DataFileReader<avro::GenericDatum>> reader_;
-  avro::ValidSchema reader_schema_;
-};
-
-class AvroInput: public FileInput<AvroInputStream> {
- public:
-  Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr<AvroInputStream>& state, int64 record_to_read, int64* record_read, std::vector<Tensor>* out_tensors) const override {
-    if (state.get() == nullptr) {
-      state.reset(new AvroInputStream(s, schema(), columns()));
-      TF_RETURN_IF_ERROR(state.get()->Open());
-    }
-    avro::GenericDatum datum(state.get()->ReaderSchema());
-    while ((*record_read) < record_to_read && state.get()->ReadDatum(datum)) {
-      const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
-      if (*record_read == 0) {
-        out_tensors->clear();
-        // Let's allocate enough space for Tensor, if more than read then slice.
-        for (size_t i = 0; i < columns().size(); i++) {
-          const string& column = columns()[i];
-          const avro::GenericDatum& field = record.field(column);
-          DataType dtype;
-          switch (field.type()) {
-          case avro::AVRO_BOOL:
-            dtype = DT_BOOL;
-            break;
-          case avro::AVRO_INT:
-            dtype = DT_INT32;
-            break;
-          case avro::AVRO_LONG:
-            dtype = DT_INT64;
-            break;
-          case avro::AVRO_FLOAT:
-            dtype = DT_FLOAT;
-            break;
-          case avro::AVRO_DOUBLE:
-            dtype = DT_DOUBLE;
-            break;
-          case avro::AVRO_STRING:
-            dtype = DT_STRING;
-            break;
-          case avro::AVRO_BYTES:
-            dtype = DT_STRING;
-            break;
-          case avro::AVRO_FIXED:
-            dtype = DT_STRING;
-            break;
-          case avro::AVRO_ENUM:
-            dtype = DT_STRING;
-            break;
-          default:
-            return errors::InvalidArgument("unsupported data type: ", field.type());
-          }
-          Tensor tensor(ctx->allocator({}), dtype, {record_to_read});
-          out_tensors->emplace_back(std::move(tensor));
-        }
-      }
-      for (size_t i = 0; i < columns().size(); i++) {
-        const string& column = columns()[i];
-        const avro::GenericDatum& field = record.field(column);
-        switch (field.type()) {
-        case avro::AVRO_BOOL:
-          ((*out_tensors)[i]).flat<bool>()(*record_read) = field.value<bool>();
-          break;
-        case avro::AVRO_INT:
-          ((*out_tensors)[i]).flat<int32>()(*record_read) = field.value<int32_t>();
-          break;
-        case avro::AVRO_LONG:
-          ((*out_tensors)[i]).flat<int64>()(*record_read) = field.value<int64_t>();
-          break;
-        case avro::AVRO_FLOAT:
-          ((*out_tensors)[i]).flat<float>()(*record_read) = field.value<float>();
-          break;
-        case avro::AVRO_DOUBLE:
-          ((*out_tensors)[i]).flat<double>()(*record_read) = field.value<double>();
-          break;
-        case avro::AVRO_STRING:
-          ((*out_tensors)[i]).flat<string>()(*record_read) = field.value<string>();
-          break;
-        case avro::AVRO_BYTES:
-          {
-            const std::vector<uint8_t>& value = field.value<std::vector<uint8_t>>();
-            string v;
-            if (value.size() > 0) {
-              v.resize(value.size());
-              memcpy(&v[0], &value[0], value.size());
-            }
-          }
-          break;
-        case avro::AVRO_FIXED:
-          {
-            const std::vector<uint8_t>& value = field.value<avro::GenericFixed>().value();
-            string v;
-            if (value.size() > 0) {
-              v.resize(value.size());
-              memcpy(&v[0], &value[0], value.size());
-            }
-          }
-          break;
-        case avro::AVRO_ENUM:
-          ((*out_tensors)[i]).flat<string>()(*record_read) = field.value<avro::GenericEnum>().symbol();
-          break;
-        default:
-          return errors::InvalidArgument("unsupported data type: ", field.type());
-        }
-      }
-      (*record_read)++;
-    }
-    // Slice if needed
-    if (*record_read < record_to_read) {
-      if (*record_read == 0) {
-        out_tensors->clear();
-      }
-      for (size_t i = 0; i < out_tensors->size(); i++) {
-        Tensor tensor = (*out_tensors)[i].Slice(0, *record_read);
-        (*out_tensors)[i] = std::move(tensor);
-      }
-    }
-    return Status::OK();
-  }
-  Status FromStream(io::InputStreamInterface* s) override {
-    return Status::OK();
-  }
-  void EncodeAttributes(VariantTensorData* data) const override {
-  }
-  bool DecodeAttributes(const VariantTensorData& data) override {
-    return true;
-  }
- protected:
-};
-
-REGISTER_UNARY_VARIANT_DECODE_FUNCTION(AvroInput, "tensorflow::data::AvroInput");
-
-REGISTER_KERNEL_BUILDER(Name("AvroInput").Device(DEVICE_CPU),
-                        FileInputOp<AvroInput>);
-REGISTER_KERNEL_BUILDER(Name("AvroDataset").Device(DEVICE_CPU),
-                        FileInputDatasetOp<AvroInput>);
-}  // namespace data
-}  // namespace tensorflow
diff --git a/tensorflow_io/avro/kernels/avro_kernels.cc b/tensorflow_io/avro/kernels/avro_kernels.cc
new file mode 100644
index 000000000..5340cf91d
--- /dev/null
+++ b/tensorflow_io/avro/kernels/avro_kernels.cc
@@ -0,0 +1,292 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow_io/core/kernels/stream.h"
+#include "api/DataFile.hh"
+#include "api/Compiler.hh"
+#include "api/Generic.hh"
+#include "api/Stream.hh"
+#include "api/Validator.hh"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+static const size_t kAvroInputStreamBufferSize = 8192;
+class AvroInputStream : public avro::SeekableInputStream {
+public:
+  AvroInputStream(tensorflow::RandomAccessFile* file)
+    : file_(file) {
+  }
+  virtual ~AvroInputStream() {}
+  bool next(const uint8_t** data, size_t* len) override {
+    if (*len == 0) {
+      *len = kAvroInputStreamBufferSize;
+    }
+    if (buffer_.size() < *len) {
+      buffer_.resize(*len);
+    }
+    StringPiece result;
+    Status status = file_->Read(byte_count_, *len, &result, &buffer_[0]);
+    *data = (const uint8_t*)buffer_.data();
+    *len = result.size();
+    byte_count_ += *len;
+    return (*len != 0);
+  }
+  void backup(size_t len) override {
+    byte_count_ -= len;
+  }
+  void skip(size_t len) override {
+    byte_count_ += len;
+  }
+  void seek(int64_t position) override {
+    byte_count_ = position;
+  }
+  size_t byteCount() const override {
+    return byte_count_;
+  }
+private:
+  tensorflow::RandomAccessFile* file_;
+  string buffer_;
+  uint64 byte_count_ = 0;
+};
+
+class ListAvroColumnsOp : public OpKernel {
+ public:
+  explicit ListAvroColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string filename = filename_tensor.scalar<string>()();
+
+    const Tensor& schema_tensor = context->input(1);
+    const string& schema = schema_tensor.scalar<string>()();
+
+    const Tensor& memory_tensor = context->input(2);
+    const string& memory = memory_tensor.scalar<string>()();
+
+    avro::ValidSchema reader_schema;
+
+    string error;
+    std::istringstream ss(schema);
+    OP_REQUIRES(context, avro::compileJsonSchema(ss, reader_schema, error), errors::Unimplemented("Avro schema error: ", error));
+
+    avro::GenericDatum datum(reader_schema.root());
+
+    std::vector<string> columns;
+    std::vector<string> dtypes;
+    columns.reserve(reader_schema.root()->names());
+    dtypes.reserve(reader_schema.root()->names());
+
+    const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
+    for (int i = 0; i < reader_schema.root()->names(); i++) {
+      const avro::GenericDatum& field = record.field(reader_schema.root()->nameAt(i));
+      string dtype;
+      switch(field.type()) {
+      case avro::AVRO_BOOL:
+        dtype = "bool";
+        break;
+      case avro::AVRO_INT:
+        dtype = "int32";
+        break;
+      case avro::AVRO_LONG:
+        dtype = "int64";
+        break;
+      case avro::AVRO_FLOAT:
+        dtype = "float";
+        break;
+      case avro::AVRO_DOUBLE:
+        dtype = "double";
+        break;
+      case avro::AVRO_STRING:
+        dtype = "string";
+        break;
+      case avro::AVRO_BYTES:
+        dtype = "string";
+        break;
+      case avro::AVRO_FIXED:
+        dtype = "string";
+        break;
+      case avro::AVRO_ENUM:
+        dtype = "string";
+        break;
+      default:
+        break;
+      }
+      if (dtype == "") {
+        continue;
+      }
+      columns.emplace_back(reader_schema.root()->nameAt(i));
+      dtypes.emplace_back(dtype);
+    }
+
+    TensorShape output_shape = filename_tensor.shape();
+    output_shape.AddDim(columns.size());
+
+    Tensor* columns_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor));
+    Tensor* dtypes_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor));
+
+    output_shape.AddDim(1);
+
+    for (size_t i = 0; i < columns.size(); i++) {
+      columns_tensor->flat<string>()(i) = columns[i];
+      dtypes_tensor->flat<string>()(i) = dtypes[i];
+    }
+  }
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+class ReadAvroOp : public OpKernel {
+ public:
+  explicit ReadAvroOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string& filename = filename_tensor.scalar<string>()();
+
+    const Tensor& schema_tensor = context->input(1);
+    const string& schema = schema_tensor.scalar<string>()();
+
+    const Tensor& column_tensor = context->input(2);
+    const string& column = column_tensor.scalar<string>()();
+
+    const Tensor& memory_tensor = context->input(3);
+    const string& memory = memory_tensor.scalar<string>()();
+
+    const Tensor& offset_tensor = context->input(4);
+    const int64 offset = offset_tensor.scalar<int64>()();
+
+    const Tensor& length_tensor = context->input(5);
+    int64 length = length_tensor.scalar<int64>()();
+
+    avro::ValidSchema reader_schema;
+
+    string error;
+    std::istringstream ss(schema);
+    OP_REQUIRES(context, avro::compileJsonSchema(ss, reader_schema, error), errors::Unimplemented("Avro schema error: ", error));
+
+    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
+    uint64 size;
+    OP_REQUIRES_OK(context, file->GetFileSize(&size));
+
+    if (length < 0) {
+      length = size - offset;
+    }
+
+    avro::GenericDatum datum(reader_schema);
+
+    std::unique_ptr<avro::InputStream> stream(new AvroInputStream(file.get()));
+    std::unique_ptr<avro::DataFileReader<avro::GenericDatum>> reader(new avro::DataFileReader<avro::GenericDatum>(std::move(stream), reader_schema));
+
+    if (offset != 0) {
+      reader->sync(offset);
+    }
+
+    #define BOOL_VALUE records.push_back(field.value<bool>())
+    #define INT32_VALUE records.emplace_back(field.value<int32_t>())
+    #define INT64_VALUE records.emplace_back(field.value<int64_t>())
+    #define FLOAT_VALUE records.emplace_back(field.value<float>())
+    #define DOUBLE_VALUE records.emplace_back(field.value<double>())
+    #define STRING_VALUE records.emplace_back(field.value<string>())
+    #define BYTES_VALUE { \
+      const std::vector<uint8_t>& value = field.value<std::vector<uint8_t>>(); \
+      string v; \
+      if (value.size() > 0) { \
+        v.resize(value.size()); \
+        memcpy(&v[0], &value[0], value.size()); \
+      } \
+      records.emplace_back(v); \
+    }
+    #define FIXED_VALUE { \
+      const std::vector<uint8_t>& value = field.value<avro::GenericFixed>().value(); \
+      string v; \
+      if (value.size() > 0) { \
+        v.resize(value.size()); \
+        memcpy(&v[0], &value[0], value.size()); \
+      } \
+      records.emplace_back(v); \
+    }
+    #define ENUM_VALUE records.emplace_back(field.value<avro::GenericEnum>().symbol())
+
+    #define PROCESS_RECORD(TYPE, ATYPE, VALUE) { \
+      std::vector<ATYPE> records; \
+      while (!reader->pastSync(offset + length) && reader->read(datum)) { \
+        const avro::GenericRecord& record = datum.value<avro::GenericRecord>(); \
+        const avro::GenericDatum& field = record.field(column); \
+        VALUE; \
+      } \
+      Tensor* output_tensor; \
+      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({static_cast<int64>(records.size())}), &output_tensor)); \
+      for (size_t i = 0; i < records.size(); i++) { \
+        output_tensor->flat<TYPE>()(i) = std::move(records[i]); \
+      } \
+    }
+    switch(datum.value<avro::GenericRecord>().field(column).type()) {
+    case avro::AVRO_BOOL:
+      PROCESS_RECORD(bool, bool, BOOL_VALUE);
+      break;
+    case avro::AVRO_INT:
+      PROCESS_RECORD(int32, int32_t, INT32_VALUE);
+      break;
+    case avro::AVRO_LONG:
+      PROCESS_RECORD(int64, int64_t, INT64_VALUE);
+      break;
+    case avro::AVRO_FLOAT:
+      PROCESS_RECORD(float, float, FLOAT_VALUE);
+      break;
+    case avro::AVRO_DOUBLE:
+      PROCESS_RECORD(double, double, DOUBLE_VALUE);
+      break;
+    case avro::AVRO_STRING:
+      PROCESS_RECORD(string, string, STRING_VALUE);
+      break;
+    case avro::AVRO_BYTES:
+      PROCESS_RECORD(string, string, BYTES_VALUE);
+      break;
+    case avro::AVRO_FIXED:
+      PROCESS_RECORD(string, string, FIXED_VALUE);
+      break;
+    case avro::AVRO_ENUM:
+      PROCESS_RECORD(string, string, ENUM_VALUE);
+      break;
+    default:
+      OP_REQUIRES(context, false, errors::InvalidArgument("unsupported data type: ", datum.value<avro::GenericRecord>().field(column).type()));
+    }
+  }
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ListAvroColumns").Device(DEVICE_CPU),
+                        ListAvroColumnsOp);
+REGISTER_KERNEL_BUILDER(Name("ReadAvro").Device(DEVICE_CPU),
+                        ReadAvroOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/avro/ops/avro_ops.cc b/tensorflow_io/avro/ops/avro_ops.cc
index acd720d3f..1c090dd83 100644
--- a/tensorflow_io/avro/ops/avro_ops.cc
+++ b/tensorflow_io/avro/ops/avro_ops.cc
@@ -19,27 +19,29 @@ limitations under the License.
 
 namespace tensorflow {
 
-REGISTER_OP("AvroInput")
-    .Input("source: string")
-    .Output("handle: variant")
-    .Attr("filters: list(string) = []")
-    .Attr("columns: list(string) = []")
-    .Attr("schema: string = ''")
+REGISTER_OP("ListAvroColumns")
+    .Input("filename: string")
+    .Input("schema: string")
+    .Input("memory: string")
+    .Output("columns: string")
+    .Output("dtypes: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
       return Status::OK();
     });
 
-REGISTER_OP("AvroDataset")
-    .Input("input: T")
-    .Input("batch: int64")
-    .Output("handle: variant")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .Attr("T: {string, variant} = DT_VARIANT")
-    .SetIsStateful()
+REGISTER_OP("ReadAvro")
+    .Input("filename: string")
+    .Input("schema: string")
+    .Input("column: string")
+    .Input("memory: string")
+    .Input("offset: int64")
+    .Input("length: int64")
+    .Attr("dtype: type")
+    .Output("output: dtype")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->MakeShape({}));
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
       return Status::OK();
     });
diff --git a/tensorflow_io/avro/python/ops/avro_ops.py b/tensorflow_io/avro/python/ops/avro_ops.py
index ce92389fb..9a5278210 100644
--- a/tensorflow_io/avro/python/ops/avro_ops.py
+++ b/tensorflow_io/avro/python/ops/avro_ops.py
@@ -18,53 +18,64 @@ from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow.compat.v1 import data
-from tensorflow_io.core.python.ops import core_ops as avro_ops
+from tensorflow_io.core.python.ops import core_ops
+from tensorflow_io.core.python.ops import data_ops
 
-class AvroDataset(data.Dataset):
+def list_avro_columns(filename, schema, **kwargs):
+  """list_avro_columns"""
+  if not tf.executing_eagerly():
+    raise NotImplementedError("list_avro_columns only supports eager mode")
+  memory = kwargs.get("memory", "")
+  columns, dtypes = core_ops.list_avro_columns(
+      filename, schema=schema, memory=memory)
+  entries = zip(tf.unstack(columns), tf.unstack(dtypes))
+  return dict([(column.numpy().decode(), tf.TensorSpec(
+      tf.TensorShape([None]),
+      dtype.numpy().decode(),
+      column.numpy().decode())) for (
+          column, dtype) in entries])
+
+def read_avro(filename, schema, column, **kwargs):
+  """read_avro"""
+  memory = kwargs.get("memory", "")
+  offset = kwargs.get("offset", 0)
+  length = kwargs.get("length", -1)
+  return core_ops.read_avro(
+      filename, schema, column.name, memory=memory,
+      offset=offset, length=length, dtype=column.dtype)
+
+class AvroDataset(data_ops.BaseDataset):
   """An Avro Dataset that reads an avro file."""
 
-  def __init__(self, filenames, columns, schema, dtypes=None, batch=None):
+  def __init__(self, filename, schema, column, **kwargs):
     """Create an `AvroDataset`.
 
     Args:
-      filenames: A 0-D or 1-D `tf.string` tensor containing one or more
-        filenames.
-      columns: A 0-D or 1-D `tf.int32` tensor containing the columns to extract.
+      filename: A string containing the filename to read.
       schema: A string containing the avro schema.
-      dtypes: A tuple of `tf.DType` objects representing the types of the
-        columns returned.
+      column: A string containing the column to extract.
     """
-    self._data_input = avro_ops.avro_input(
-        filenames, ["none", "gz"], columns=columns, schema=schema)
-    self._columns = columns
-    self._schema = schema
-    self._dtypes = dtypes
-    self._batch = 0 if batch is None else batch
-    super(AvroDataset, self).__init__()
-
-  def _inputs(self):
-    return []
-
-  def _as_variant_tensor(self):
-    return avro_ops.avro_dataset(
-        self._data_input,
-        self._batch,
-        output_types=self.output_types,
-        output_shapes=self.output_shapes)
-
-  @property
-  def output_classes(self):
-    return tuple([tf.Tensor for _ in self._columns])
+    if not tf.executing_eagerly():
+      dtype = kwargs.get("dtype")
+    else:
+      columns = list_avro_columns(filename, schema)
+      dtype = columns[column].dtype
+    shape = tf.TensorShape([None])
 
-  @property
-  def output_shapes(self):
-    return tuple(
-        [tf.TensorShape([]) for _ in self._columns]
-    ) if self._batch is None else tuple(
-        [tf.TensorShape([None]) for _ in self._columns]
-    )
+    filesize = tf.io.gfile.GFile(filename).size()
+    # capacity is the rough length (in bytes) of each split
+    capacity = kwargs.get("capacity", 65536)
+    entry_offset = list(range(0, filesize, capacity))
+    entry_length = [min(capacity, filesize - offset) for offset in entry_offset]
+    dataset = data_ops.BaseDataset.from_tensor_slices(
+        (
+            tf.constant(entry_offset, tf.int64),
+            tf.constant(entry_length, tf.int64)
+        )
+    ).map(lambda offset, length: core_ops.read_avro(
+        filename, schema, column, memory="",
+        offset=offset, length=length, dtype=dtype))
+    self._dataset = dataset
 
-  @property
-  def output_types(self):
-    return self._dtypes
+    super(AvroDataset, self).__init__(
+        self._dataset._variant_tensor, [dtype], [shape])  # pylint: disable=protected-access
diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 25c62ec34..07bdc0b36 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -28,6 +28,7 @@ cc_library(
     name = "dataset_ops",
     srcs = [
         "kernels/dataset_ops.h",
+        "kernels/stream.h",
     ],
     copts = tf_io_copts(),
     includes = [
diff --git a/tensorflow_io/core/kernels/stream.h b/tensorflow_io/core/kernels/stream.h
new file mode 100644
index 000000000..e812babf4
--- /dev/null
+++ b/tensorflow_io/core/kernels/stream.h
@@ -0,0 +1,71 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+
+namespace tensorflow {
+namespace data {
+
+// Note: This SizedRandomAccessFile should only live within Compute()
+// of the kernel, as the buffer could be released outside.
+class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
+ public:
+  SizedRandomAccessFile(Env* env, const string& filename, const void* optional_memory_buff, const size_t optional_memory_size)
+    : file_(nullptr)
+    , size_(optional_memory_size)
+    , buff_((const char *)(optional_memory_buff))
+    , size_status_(Status::OK()) {
+    if (size_ == 0) {
+      size_status_ = env->GetFileSize(filename, &size_);
+      if (size_status_.ok()) {
+        size_status_ = env->NewRandomAccessFile(filename, &file_);
+      }
+    }
+  }
+
+  virtual ~SizedRandomAccessFile() {}
+  Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
+    if (file_.get() != nullptr) {
+      return file_.get()->Read(offset, n, result, scratch);
+    }
+    size_t bytes_to_read = 0;
+    if (offset < size_) {
+      bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
+    }
+    if (bytes_to_read > 0) {
+      memcpy(scratch, &buff_[offset], bytes_to_read);
+    }
+    *result = StringPiece(scratch, bytes_to_read);
+    if (bytes_to_read < n) {
+      return errors::OutOfRange("EOF reached");
+    }
+    return Status::OK();
+  }
+  Status GetFileSize(uint64* size) {
+    if (size_status_.ok()) {
+      *size = size_;
+    }
+    return size_status_;
+  }
+ private:
+  std::unique_ptr<tensorflow::RandomAccessFile> file_;
+  uint64 size_;
+  const char *buff_;
+  Status size_status_;
+};
+
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py
index bbd3b5d8b..40107037d 100644
--- a/tensorflow_io/core/python/ops/data_ops.py
+++ b/tensorflow_io/core/python/ops/data_ops.py
@@ -58,9 +58,8 @@ def _apply_fn(dataset):
 class BaseDataset(tf.compat.v2.data.Dataset):
   """A Base Dataset"""
 
-  def __init__(self, variant, batch, dtypes, shapes):
+  def __init__(self, variant, dtypes, shapes):
     """Create a Base Dataset."""
-    self._batch = 0 if batch is None else batch
     self._dtypes = dtypes
     self._shapes = shapes
     super(BaseDataset, self).__init__(variant)
@@ -93,4 +92,4 @@ def __init__(self, fn, data_input, batch, dtypes, shapes):
         self._data_input,
         self._batch,
         output_types=self._dtypes,
-        output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)
+        output_shapes=self._shapes), self._dtypes, self._shapes)
diff --git a/tests/test_avro.py b/tests/test_avro.py
index 6024b11ca..d134ee26c 100644
--- a/tests/test_avro.py
+++ b/tests/test_avro.py
@@ -19,74 +19,47 @@ from __future__ import print_function
 
 import os
+import numpy as np
+import pytest
 import tensorflow as tf
 tf.compat.v1.disable_eager_execution()
-
-from tensorflow import dtypes  # pylint: disable=wrong-import-position
-from tensorflow import errors  # pylint: disable=wrong-import-position
-from tensorflow import test  # pylint: disable=wrong-import-position
-from tensorflow.compat.v1 import data  # pylint: disable=wrong-import-position
-
 import tensorflow_io.avro as avro_io  # pylint: disable=wrong-import-position
 
-class AvroDatasetTest(test.TestCase):
-  """AvroDatasetTest"""
-
-  def test_avro_dataset(self):
-    """Test case for AvroDataset."""
-    # The test.bin was created from avro/lang/c++/examples/datafile.cc.
-    filename = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)),
-        "test_avro", "test.bin")
-    filename = "file://" + filename
-
-    schema_filename = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)),
-        "test_avro", "cpx.json")
-    with open(schema_filename, 'r') as f:
-      schema = f.read()
-
-    columns = ['im', 're']
-    output_types = (dtypes.float64, dtypes.float64)
-    num_repeats = 2
+def test_avro_dataset():
+  """Test case for AvroDataset."""
+  # The test.bin was created from avro/lang/c++/examples/datafile.cc.
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_avro", "test.bin")
+  filename = "file://" + filename
 
-    dataset = avro_io.AvroDataset(
-        [filename], columns, schema, output_types).repeat(num_repeats)
-    iterator = data.make_initializable_iterator(dataset)
-    init_op = iterator.initializer
-    get_next = iterator.get_next()
+  schema_filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_avro", "cpx.json")
+  with open(schema_filename, 'r') as f:
+    schema = f.read()
 
-    with self.test_session() as sess:
-      sess.run(init_op)
-      for _ in range(num_repeats):
-        for i in range(100):
-          (im, re) = (i + 100, i * 100)
-          vv = sess.run(get_next)
-          self.assertAllClose((im, re), vv)
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
+  num_repeats = 2
+  dataset = tf.compat.v2.data.Dataset.zip(
+      (
+          avro_io.AvroDataset(filename, schema, "im", dtype=tf.float64),
+          avro_io.AvroDataset(filename, schema, "re", dtype=tf.float64)
+      )).repeat(num_repeats).apply(tf.data.experimental.unbatch())
 
-    dataset = avro_io.AvroDataset(
-        [filename, filename], columns, schema, output_types, batch=3)
-    iterator = data.make_initializable_iterator(dataset)
-    init_op = iterator.initializer
-    get_next = iterator.get_next()
+  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
+  init_op = iterator.initializer
+  get_next = iterator.get_next()
 
-    with self.test_session() as sess:
-      sess.run(init_op)
-      for ii in range(0, 198, 3):
-        i = ii % 100
-        (im, re) = (
-            [i + 100, ((i + 1) % 100) + 100, ((i + 2) % 100) + 100],
-            [i * 100, ((i + 1) % 100) * 100, ((i + 2) % 100) * 100])
+  with tf.compat.v1.Session() as sess:
+    sess.run(init_op)
+    for _ in range(num_repeats):
+      for i in range(100):
+        (im, re) = (i + 100, i * 100)
         vv = sess.run(get_next)
-        self.assertAllClose((im, re), vv)
-      (im, re) = ([198, 199], [9800, 9900])
-      vv = sess.run(get_next)
-      self.assertAllClose((im, re), vv)
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
+        assert np.allclose((im, re), vv)
+    with pytest.raises(tf.errors.OutOfRangeError):
+      sess.run(get_next)
 
 if __name__ == "__main__":
-  test.main()
+  tf.test.main()
diff --git a/tests/test_avro_eager.py b/tests/test_avro_eager.py
new file mode 100644
index 000000000..172697239
--- /dev/null
+++ b/tests/test_avro_eager.py
@@ -0,0 +1,69 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for AvroDataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
+  tf.compat.v1.enable_eager_execution()
+import tensorflow_io.avro as avro_io  # pylint: disable=wrong-import-position
+
+def test_avro():
+  """Test list_avro_columns, read_avro, and AvroDataset in eager mode."""
+  # The test.bin was created from avro/lang/c++/examples/datafile.cc.
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_avro", "test.bin")
+  filename = "file://" + filename
+
+  schema_filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_avro", "cpx.json")
+  with open(schema_filename, 'r') as f:
+    schema = f.read()
+
+  specs = avro_io.list_avro_columns(filename, schema)
+  assert specs["im"].dtype == tf.float64
+  assert specs["re"].dtype == tf.float64
+
+  v0 = avro_io.read_avro(filename, schema, specs["im"])
+  v1 = avro_io.read_avro(filename, schema, specs["re"])
+  for i in range(100):
+    (im, re) = (i + 100, i * 100)
+    assert v0[i].numpy() == im
+    assert v1[i].numpy() == re
+
+  for capacity in [10, 20, 50, 100, 1000, 2000]:
+    dataset = tf.compat.v2.data.Dataset.zip(
+        (
+            avro_io.AvroDataset(filename, schema, "im", capacity=capacity),
+            avro_io.AvroDataset(filename, schema, "re", capacity=capacity)
+        )
+    ).apply(tf.data.experimental.unbatch())
+    i = 0
+    for vv in dataset:
+      v0, v1 = vv
+      (im, re) = (i + 100, i * 100)
+      assert v0.numpy() == im
+      assert v1.numpy() == re
+      i += 1
+
+if __name__ == "__main__":
+  tf.test.main()
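
For reviewers, a rough usage sketch of the API added in this patch. It is not
part of the diff itself; it is based on the tests above, and assumes eager
mode plus the repository's test fixtures (tests/test_avro/test.bin and
tests/test_avro/cpx.json):

    import tensorflow as tf
    import tensorflow_io.avro as avro_io

    # Schema and data file; paths assume the repository's test fixtures.
    with open("tests/test_avro/cpx.json") as f:
      schema = f.read()
    filename = "file://" + "tests/test_avro/test.bin"

    # In eager mode, list_avro_columns returns a dict mapping each column
    # name to a tf.TensorSpec whose dtype is inferred from the avro schema.
    specs = avro_io.list_avro_columns(filename, schema)

    # read_avro reads one column into a 1-D tensor; the optional offset and
    # length kwargs select a byte range of the file, which is what makes the
    # read splittable.
    values = avro_io.read_avro(filename, schema, specs["im"])

    # AvroDataset wires the same primitive op into tf.data; capacity is the
    # rough byte length of each split. In graph mode, pass dtype=... instead,
    # since the dtype cannot be inferred automatically there.
    dataset = avro_io.AvroDataset(filename, schema, "im", capacity=65536)
    for chunk in dataset:
      print(chunk)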