diff --git a/tensorflow_io/avro/BUILD b/tensorflow_io/avro/BUILD
index f9e288339..77e0d9952 100644
--- a/tensorflow_io/avro/BUILD
+++ b/tensorflow_io/avro/BUILD
@@ -10,7 +10,7 @@ load(
 cc_library(
     name = "avro_ops",
     srcs = [
-        "kernels/avro_input.cc",
+        "kernels/avro_kernels.cc",
         "ops/avro_ops.cc",
     ],
     copts = tf_io_copts(),
diff --git a/tensorflow_io/avro/__init__.py b/tensorflow_io/avro/__init__.py
index 11e176484..ecf68a385 100644
--- a/tensorflow_io/avro/__init__.py
+++ b/tensorflow_io/avro/__init__.py
@@ -15,6 +15,8 @@
 """Avro Dataset.
 
 @@AvroDataset
+@@list_avro_columns
+@@read_avro
 """
 
 from __future__ import absolute_import
@@ -22,11 +24,15 @@ from __future__ import print_function
 
 from tensorflow_io.avro.python.ops.avro_ops import AvroDataset
+from tensorflow_io.avro.python.ops.avro_ops import list_avro_columns
+from tensorflow_io.avro.python.ops.avro_ops import read_avro
 
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
     "AvroDataset",
+    "list_avro_columns",
+    "read_avro",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow_io/avro/kernels/avro_input.cc b/tensorflow_io/avro/kernels/avro_input.cc
deleted file mode 100644
index 8f081165d..000000000
--- a/tensorflow_io/avro/kernels/avro_input.cc
+++ /dev/null
@@ -1,248 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "kernels/dataset_ops.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
-#include "api/DataFile.hh"
-#include "api/Compiler.hh"
-#include "api/Generic.hh"
-#include "api/Stream.hh"
-#include <sstream>
-
-namespace tensorflow {
-namespace data {
-static const size_t kAvroDataInputStreamBufferSize = 8192;
-class AvroDataInputStream : public avro::InputStream {
-public:
-  AvroDataInputStream(io::InputStreamInterface* s)
-    : stream_(s) {}
-  virtual ~AvroDataInputStream() {}
-  bool next(const uint8_t** data, size_t* len) override {
-    if (*len == 0) {
-      *len = kAvroDataInputStreamBufferSize;
-    }
-    if (*len <= prefix_.size()) {
-      buffer_ = prefix_.substr(0, *len);
-      prefix_ = prefix_.substr(*len);
-    } else {
-      int64 bytes_to_read = *len - prefix_.size();
-      string chunk;
-      stream_->ReadNBytes(bytes_to_read, &chunk);
-      buffer_ = std::move(prefix_);
-      buffer_.append(chunk);
-      prefix_.clear();
-    }
-    *data = (const uint8_t*)buffer_.data();
-    *len = buffer_.size();
-    byte_count_ += *len;
-    return (*len != 0);
-  }
-  void backup(size_t len) override {
-    string chunk = buffer_.substr(buffer_.size() - len);
-    chunk.append(prefix_);
-    prefix_ = std::move(chunk);
-    byte_count_ -= len;
-  }
-  void skip(size_t len) override {
-    if (len <= prefix_.size()) {
-      prefix_ = prefix_.substr(len);
-    } else {
-      int64 bytes_to_read = len - prefix_.size();
-      stream_->SkipNBytes(bytes_to_read);
-      prefix_.clear();
-    }
-    byte_count_ += len;
-  }
-  size_t byteCount() const override {
-    return byte_count_;
-  }
-private:
-  io::InputStreamInterface* stream_;
-  size_t byte_count_ = 0;
-  string prefix_;
-  string buffer_;
-};
-
-class AvroInputStream {
-public:
-  explicit AvroInputStream(io::InputStreamInterface* s, const string& schema, const std::vector<string>& columns)
-    : stream_(s)
-    , schema_(schema)
-    , columns_(columns)
-    , reader_(nullptr) {
-  }
-
-  ~AvroInputStream() {
-    reader_.reset(nullptr);
-  }
-
-  Status Open() {
-    string error;
-    std::istringstream ss(schema_);
-    if (!avro::compileJsonSchema(ss, reader_schema_, error)) {
-      return errors::Unimplemented("Avro schema error: ", error);
-    }
-    std::unique_ptr<avro::InputStream> stream(static_cast<avro::InputStream*>(new AvroDataInputStream(stream_)));
-    reader_.reset(new avro::DataFileReader<avro::GenericDatum>(std::move(stream), reader_schema_));
-    return Status::OK();
-  }
-  const avro::ValidSchema& ReaderSchema() const {
-    return reader_schema_;
-  }
-  bool ReadDatum(avro::GenericDatum& datum) {
-    return reader_->read(datum);
-  }
-private:
-  io::InputStreamInterface* stream_;
-  string schema_;
-  std::vector<string> columns_;
-  std::unique_ptr<avro::DataFileReader<avro::GenericDatum> > reader_;
-  avro::ValidSchema reader_schema_;
-};
-
-class AvroInput: public FileInput<AvroInputStream> {
- public:
-  Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr<AvroInputStream>& state, int64 record_to_read, int64* record_read, std::vector<Tensor>* out_tensors) const override {
-    if (state.get() == nullptr) {
-      state.reset(new AvroInputStream(s, schema(), columns()));
-      TF_RETURN_IF_ERROR(state.get()->Open());
-    }
-    avro::GenericDatum datum(state.get()->ReaderSchema());
-    while ((*record_read) < record_to_read && state.get()->ReadDatum(datum)) {
-      const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
-      if (*record_read == 0) {
-        out_tensors->clear();
-        // Let's allocate enough space for Tensor, if more than read then slice.
-        for (size_t i = 0; i < columns().size(); i++) {
-          const string& column = columns()[i];
-          const avro::GenericDatum& field = record.field(column);
-          DataType dtype;
-          switch (field.type()) {
-          case avro::AVRO_BOOL:
-            dtype = DT_BOOL;
-            break;
-          case avro::AVRO_INT:
-            dtype = DT_INT32;
-            break;
-          case avro::AVRO_LONG:
-            dtype = DT_INT64;
-            break;
-          case avro::AVRO_FLOAT:
-            dtype = DT_FLOAT;
-            break;
-          case avro::AVRO_DOUBLE:
-            dtype = DT_DOUBLE;
-            break;
-          case avro::AVRO_STRING:
-            dtype = DT_STRING;
-            break;
-          case avro::AVRO_BYTES:
-            dtype = DT_STRING;
-            break;
-          case avro::AVRO_FIXED:
-            dtype = DT_STRING;
-            break;
-          case avro::AVRO_ENUM:
-            dtype = DT_STRING;
-            break;
-          default:
-            return errors::InvalidArgument("unsupported data type: ", field.type());
-          }
-          Tensor tensor(ctx->allocator({}), dtype, {record_to_read});
-          out_tensors->emplace_back(std::move(tensor));
-        }
-      }
-      for (size_t i = 0; i < columns().size(); i++) {
-        const string& column = columns()[i];
-        const avro::GenericDatum& field = record.field(column);
-        switch (field.type()) {
-        case avro::AVRO_BOOL:
-          ((*out_tensors)[i]).flat<bool>()(*record_read) = field.value<bool>();
-          break;
-        case avro::AVRO_INT:
-          ((*out_tensors)[i]).flat<int32>()(*record_read) = field.value<int32_t>();
-          break;
-        case avro::AVRO_LONG:
-          ((*out_tensors)[i]).flat<int64>()(*record_read) = field.value<int64_t>();
-          break;
-        case avro::AVRO_FLOAT:
-          ((*out_tensors)[i]).flat<float>()(*record_read) = field.value<float>();
-          break;
-        case avro::AVRO_DOUBLE:
-          ((*out_tensors)[i]).flat<double>()(*record_read) = field.value<double>();
-          break;
-        case avro::AVRO_STRING:
-          ((*out_tensors)[i]).flat<string>()(*record_read) = field.value<string>();
-          break;
-        case avro::AVRO_BYTES:
-          {
-            const std::vector<uint8_t>& value = field.value<std::vector<uint8_t> >();
-            string v;
-            if (value.size() > 0) {
-              v.resize(value.size());
-              memcpy(&v[0], &value[0], value.size());
-            }
-          }
-          break;
-        case avro::AVRO_FIXED:
-          {
-            const std::vector<uint8_t>& value = field.value<avro::GenericFixed>().value();
-            string v;
-            if (value.size() > 0) {
-              v.resize(value.size());
-              memcpy(&v[0], &value[0], value.size());
-            }
-          }
-          break;
-        case avro::AVRO_ENUM:
-          ((*out_tensors)[i]).flat<string>()(*record_read) = field.value<avro::GenericEnum>().symbol();
-          break;
-        default:
-          return errors::InvalidArgument("unsupported data type: ", field.type());
-        }
-      }
-      (*record_read)++;
-    }
-    // Slice if needed
-    if (*record_read < record_to_read) {
-      if (*record_read == 0) {
-        out_tensors->clear();
-      }
-      for (size_t i = 0; i < out_tensors->size(); i++) {
-        Tensor tensor = (*out_tensors)[i].Slice(0, *record_read);
-        (*out_tensors)[i] = std::move(tensor);
-      }
-    }
-    return Status::OK();
-  }
-  Status FromStream(io::InputStreamInterface* s) override {
-    return Status::OK();
-  }
-  void EncodeAttributes(VariantTensorData* data) const override {
-  }
-  bool DecodeAttributes(const VariantTensorData& data) override {
-    return true;
-  }
- protected:
-};
-
-REGISTER_UNARY_VARIANT_DECODE_FUNCTION(AvroInput, "tensorflow::data::AvroInput");
-
-REGISTER_KERNEL_BUILDER(Name("AvroInput").Device(DEVICE_CPU),
-                        FileInputOp<AvroInput>);
-REGISTER_KERNEL_BUILDER(Name("AvroDataset").Device(DEVICE_CPU),
-                        FileInputDatasetOp<AvroInput>);
-}  // namespace data
-}  // namespace tensorflow
diff --git a/tensorflow_io/avro/kernels/avro_kernels.cc b/tensorflow_io/avro/kernels/avro_kernels.cc
new file mode 100644
index 000000000..5340cf91d
--- /dev/null
+++ b/tensorflow_io/avro/kernels/avro_kernels.cc
@@ -0,0 +1,292 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow_io/core/kernels/stream.h"
+#include "api/DataFile.hh"
+#include "api/Compiler.hh"
+#include "api/Generic.hh"
+#include "api/Stream.hh"
+#include "api/Validator.hh"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+static const size_t kAvroInputStreamBufferSize = 8192;
+class AvroInputStream : public avro::SeekableInputStream {
+public:
+  AvroInputStream(tensorflow::RandomAccessFile* file)
+    : file_(file) {
+  }
+  virtual ~AvroInputStream() {}
+  bool next(const uint8_t** data, size_t* len) override {
+    if (*len == 0) {
+      *len = kAvroInputStreamBufferSize;
+    }
+    if (buffer_.size() < *len) {
+      buffer_.resize(*len);
+    }
+    StringPiece result;
+    Status status = file_->Read(byte_count_, *len, &result, &buffer_[0]);
+    *data = (const uint8_t*)buffer_.data();
+    *len = result.size();
+    byte_count_ += *len;
+    return (*len != 0);
+  }
+  void backup(size_t len) override {
+    byte_count_ -= len;
+  }
+  void skip(size_t len) override {
+    byte_count_ += len;
+  }
+  void seek(int64_t position) override {
+    byte_count_ = position;
+  }
+  size_t byteCount() const override {
+    return byte_count_;
+  }
+private:
+  tensorflow::RandomAccessFile* file_;
+  string buffer_;
+  uint64 byte_count_ = 0;
+};
+
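+// ListAvroColumnsOp compiles the JSON schema, walks the top-level record
+// fields, and emits, for each field with a supported primitive type, the
+// column name and the TensorFlow dtype string it maps to; fields with
+// unsupported types are skipped.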
+class ListAvroColumnsOp : public OpKernel {
+ public:
+  explicit ListAvroColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string filename = filename_tensor.scalar<string>()();
+
+    const Tensor& schema_tensor = context->input(1);
+    const string& schema = schema_tensor.scalar<string>()();
+
+    const Tensor& memory_tensor = context->input(2);
+    const string& memory = memory_tensor.scalar<string>()();
+
+    avro::ValidSchema reader_schema;
+
+    string error;
+    std::istringstream ss(schema);
+    OP_REQUIRES(context, avro::compileJsonSchema(ss, reader_schema, error), errors::Unimplemented("Avro schema error: ", error));
+
+    avro::GenericDatum datum(reader_schema.root());
+
+    std::vector<string> columns;
+    std::vector<string> dtypes;
+    columns.reserve(reader_schema.root()->names());
+    dtypes.reserve(reader_schema.root()->names());
+
+    const avro::GenericRecord& record = datum.value<avro::GenericRecord>();
+    for (int i = 0; i < reader_schema.root()->names(); i++) {
+      const avro::GenericDatum& field = record.field(reader_schema.root()->nameAt(i));
+      string dtype;
+      switch(field.type()) {
+      case avro::AVRO_BOOL:
+        dtype = "bool";
+        break;
+      case avro::AVRO_INT:
+        dtype = "int32";
+        break;
+      case avro::AVRO_LONG:
+        dtype = "int64";
+        break;
+      case avro::AVRO_FLOAT:
+        dtype = "float";
+        break;
+      case avro::AVRO_DOUBLE:
+        dtype = "double";
+        break;
+      case avro::AVRO_STRING:
+        dtype = "string";
+        break;
+      case avro::AVRO_BYTES:
+        dtype = "string";
+        break;
+      case avro::AVRO_FIXED:
+        dtype = "string";
+        break;
+      case avro::AVRO_ENUM:
+        dtype = "string";
+        break;
+      default:
+        break;
+      }
+      if (dtype == "") {
+        continue;
+      }
+      columns.emplace_back(reader_schema.root()->nameAt(i));
+      dtypes.emplace_back(dtype);
+    }
+
+    TensorShape output_shape = filename_tensor.shape();
+    output_shape.AddDim(columns.size());
+
+    Tensor* columns_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor));
+    Tensor* dtypes_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor));
+
+    output_shape.AddDim(1);
+
+    for (size_t i = 0; i < columns.size(); i++) {
+      columns_tensor->flat<string>()(i) = columns[i];
+      dtypes_tensor->flat<string>()(i) = dtypes[i];
+    }
+  }
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+class ReadAvroOp : public OpKernel {
+ public:
+  explicit ReadAvroOp(OpKernelConstruction* context) : OpKernel(context) {
+    env_ = context->env();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filename_tensor = context->input(0);
+    const string& filename = filename_tensor.scalar<string>()();
+
+    const Tensor& schema_tensor = context->input(1);
+    const string& schema = schema_tensor.scalar<string>()();
+
+    const Tensor& column_tensor = context->input(2);
+    const string& column = column_tensor.scalar<string>()();
+
+    const Tensor& memory_tensor = context->input(3);
+    const string& memory = memory_tensor.scalar<string>()();
+
+    const Tensor& offset_tensor = context->input(4);
+    const int64 offset = offset_tensor.scalar<int64>()();
+
+    const Tensor& length_tensor = context->input(5);
+    int64 length = length_tensor.scalar<int64>()();
+
+    avro::ValidSchema reader_schema;
+
+    string error;
+    std::istringstream ss(schema);
+    OP_REQUIRES(context, avro::compileJsonSchema(ss, reader_schema, error), errors::Unimplemented("Avro schema error: ", error));
+
+    std::unique_ptr<SizedRandomAccessFile> file(new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
+    uint64 size;
+    OP_REQUIRES_OK(context, file->GetFileSize(&size));
+
+    if (length < 0) {
+      length = size - offset;
+    }
+
+    avro::GenericDatum datum(reader_schema);
+
+    std::unique_ptr<avro::InputStream> stream(new AvroInputStream(file.get()));
+    std::unique_ptr<avro::DataFileReader<avro::GenericDatum> > reader(new avro::DataFileReader<avro::GenericDatum>(std::move(stream), reader_schema));
+
+    if (offset != 0) {
+      reader->sync(offset);
+    }
+
+    #define BOOL_VALUE records.push_back(field.value<bool>())
+    #define INT32_VALUE records.emplace_back(field.value<int32_t>())
+    #define INT64_VALUE records.emplace_back(field.value<int64_t>())
+    #define FLOAT_VALUE records.emplace_back(field.value<float>())
+    #define DOUBLE_VALUE records.emplace_back(field.value<double>())
+    #define STRING_VALUE records.emplace_back(field.value<string>())
+    #define BYTES_VALUE {                                                  \
+      const std::vector<uint8_t>& value = field.value<std::vector<uint8_t> >(); \
+      string v;                                                            \
+      if (value.size() > 0) {                                              \
+        v.resize(value.size());                                            \
+        memcpy(&v[0], &value[0], value.size());                            \
+      }                                                                    \
+      records.emplace_back(v);                                             \
+    }
+    #define FIXED_VALUE {                                                  \
+      const std::vector<uint8_t>& value = field.value<avro::GenericFixed>().value(); \
+      string v;                                                            \
+      if (value.size() > 0) {                                              \
+        v.resize(value.size());                                            \
+        memcpy(&v[0], &value[0], value.size());                            \
+      }                                                                    \
+      records.emplace_back(v);                                             \
+    }
+    #define ENUM_VALUE records.emplace_back(field.value<avro::GenericEnum>().symbol())
+
+    #define PROCESS_RECORD(TYPE, ATYPE, VALUE) {                           \
+      std::vector<ATYPE> records;                                          \
+      while (!reader->pastSync(offset + length) && reader->read(datum)) {  \
+        const avro::GenericRecord& record = datum.value<avro::GenericRecord>(); \
+        const avro::GenericDatum& field = record.field(column);            \
+        VALUE;                                                             \
+      }                                                                    \
+      Tensor* output_tensor;                                               \
+      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({static_cast<int64>(records.size())}), &output_tensor)); \
+      for (size_t i = 0; i < records.size(); i++) {                        \
+        output_tensor->flat<TYPE>()(i) = std::move(records[i]);            \
+      }                                                                    \
+    }
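+    // Dispatch on the Avro type of the requested column: each case drains
+    // every record between the current sync marker and offset + length into
+    // a std::vector, then copies the collected values into a 1-D output
+    // tensor sized to the number of records actually read.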
+    switch(datum.value<avro::GenericRecord>().field(column).type()) {
+    case avro::AVRO_BOOL:
+      PROCESS_RECORD(bool, bool, BOOL_VALUE);
+      break;
+    case avro::AVRO_INT:
+      PROCESS_RECORD(int32, int32_t, INT32_VALUE);
+      break;
+    case avro::AVRO_LONG:
+      PROCESS_RECORD(int64, int64_t, INT64_VALUE);
+      break;
+    case avro::AVRO_FLOAT:
+      PROCESS_RECORD(float, float, FLOAT_VALUE);
+      break;
+    case avro::AVRO_DOUBLE:
+      PROCESS_RECORD(double, double, DOUBLE_VALUE);
+      break;
+    case avro::AVRO_STRING:
+      PROCESS_RECORD(string, string, STRING_VALUE);
+      break;
+    case avro::AVRO_BYTES:
+      PROCESS_RECORD(string, string, BYTES_VALUE);
+      break;
+    case avro::AVRO_FIXED:
+      PROCESS_RECORD(string, string, FIXED_VALUE);
+      break;
+    case avro::AVRO_ENUM:
+      PROCESS_RECORD(string, string, ENUM_VALUE);
+      break;
+    default:
+      OP_REQUIRES(context, false, errors::InvalidArgument("unsupported data type: ", datum.value<avro::GenericRecord>().field(column).type()));
+    }
+  }
+ private:
+  mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ListAvroColumns").Device(DEVICE_CPU),
+                        ListAvroColumnsOp);
+REGISTER_KERNEL_BUILDER(Name("ReadAvro").Device(DEVICE_CPU),
+                        ReadAvroOp);
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/avro/ops/avro_ops.cc b/tensorflow_io/avro/ops/avro_ops.cc
index acd720d3f..1c090dd83 100644
--- a/tensorflow_io/avro/ops/avro_ops.cc
+++ b/tensorflow_io/avro/ops/avro_ops.cc
@@ -19,27 +19,29 @@ limitations under the License.
 
 namespace tensorflow {
 
-REGISTER_OP("AvroInput")
-    .Input("source: string")
-    .Output("handle: variant")
-    .Attr("filters: list(string) = []")
-    .Attr("columns: list(string) = []")
-    .Attr("schema: string = ''")
+REGISTER_OP("ListAvroColumns")
+    .Input("filename: string")
+    .Input("schema: string")
+    .Input("memory: string")
+    .Output("columns: string")
+    .Output("dtypes: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
       return Status::OK();
     });
 
-REGISTER_OP("AvroDataset")
-    .Input("input: T")
-    .Input("batch: int64")
-    .Output("handle: variant")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .Attr("T: {string, variant} = DT_VARIANT")
-    .SetIsStateful()
+REGISTER_OP("ReadAvro")
+    .Input("filename: string")
+    .Input("schema: string")
+    .Input("column: string")
+    .Input("memory: string")
+    .Input("offset: int64")
+    .Input("length: int64")
+    .Attr("dtype: type")
+    .Output("output: dtype")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->MakeShape({}));
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
       return Status::OK();
     });
diff --git a/tensorflow_io/avro/python/ops/avro_ops.py b/tensorflow_io/avro/python/ops/avro_ops.py
index ce92389fb..9a5278210 100644
--- a/tensorflow_io/avro/python/ops/avro_ops.py
+++ b/tensorflow_io/avro/python/ops/avro_ops.py
@@ -18,53 +18,64 @@ from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow.compat.v1 import data
-from tensorflow_io.core.python.ops import core_ops as avro_ops
+from tensorflow_io.core.python.ops import core_ops
+from tensorflow_io.core.python.ops import data_ops
 
-class AvroDataset(data.Dataset):
+def list_avro_columns(filename, schema, **kwargs):
+  """Lists the columns of an Avro file as a dict of `tf.TensorSpec`."""
+  if not tf.executing_eagerly():
+    raise NotImplementedError("list_avro_columns only supports eager mode")
+  memory = kwargs.get("memory", "")
+  columns, dtypes = core_ops.list_avro_columns(
+      filename, schema=schema, memory=memory)
+  entries = zip(tf.unstack(columns), tf.unstack(dtypes))
+  return dict([(column.numpy().decode(), tf.TensorSpec(
+      tf.TensorShape([None]),
+      dtype.numpy().decode(),
+      column.numpy().decode())) for (
+          column, dtype) in entries])
+
+def read_avro(filename, schema, column, **kwargs):
+  """Reads a column (a `tf.TensorSpec` from `list_avro_columns`) into a tensor."""
+  memory = kwargs.get("memory", "")
+  offset = kwargs.get("offset", 0)
+  length = kwargs.get("length", -1)
+  return core_ops.read_avro(
+      filename, schema, column.name, memory=memory,
+      offset=offset, length=length, dtype=column.dtype)
+
+class AvroDataset(data_ops.BaseDataset):
   """A Avro Dataset that reads the avro file."""
 
-  def __init__(self, filenames, columns, schema, dtypes=None, batch=None):
+  def __init__(self, filename, schema, column, **kwargs):
     """Create a `AvroDataset`.
 
     Args:
-      filenames: A 0-D or 1-D `tf.string` tensor containing one or more
-        filenames.
-      columns: A 0-D or 1-D `tf.int32` tensor containing the columns to extract.
+      filename: A string containing the filename to read.
       schema: A string containing the avro schema.
-      dtypes: A tuple of `tf.DType` objects representing the types of the
-        columns returned.
+      column: A string containing the column to extract.
     """
-    self._data_input = avro_ops.avro_input(
-        filenames, ["none", "gz"], columns=columns, schema=schema)
-    self._columns = columns
-    self._schema = schema
-    self._dtypes = dtypes
-    self._batch = 0 if batch is None else batch
-    super(AvroDataset, self).__init__()
-
-  def _inputs(self):
-    return []
-
-  def _as_variant_tensor(self):
-    return avro_ops.avro_dataset(
-        self._data_input,
-        self._batch,
-        output_types=self.output_types,
-        output_shapes=self.output_shapes)
-
-  @property
-  def output_classes(self):
-    return tuple([tf.Tensor for _ in self._columns])
+    if not tf.executing_eagerly():
+      dtype = kwargs.get("dtype")
+    else:
+      columns = list_avro_columns(filename, schema)
+      dtype = columns[column].dtype
+    shape = tf.TensorShape([None])
 
-  @property
-  def output_shapes(self):
-    return tuple(
-        [tf.TensorShape([]) for _ in self._columns]
-    ) if self._batch is None else tuple(
-        [tf.TensorShape([None]) for _ in self._columns]
-    )
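+    # Split the file into byte ranges of roughly `capacity` bytes each.
+    # ReadAvro aligns every [offset, offset + length) range to the Avro
+    # sync markers (via DataFileReader sync/pastSync on the C++ side), so
+    # each record is produced by exactly one split even though the offsets
+    # do not coincide with block boundaries.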
+    filesize = tf.io.gfile.GFile(filename).size()
+    # capacity is the rough length for each split
+    capacity = kwargs.get("capacity", 65536)
+    entry_offset = list(range(0, filesize, capacity))
+    entry_length = [min(capacity, filesize - offset) for offset in entry_offset]
+    dataset = data_ops.BaseDataset.from_tensor_slices(
+        (
+            tf.constant(entry_offset, tf.int64),
+            tf.constant(entry_length, tf.int64)
+        )
+    ).map(lambda offset, length: core_ops.read_avro(
+        filename, schema, column, memory="",
+        offset=offset, length=length, dtype=dtype))
+    self._dataset = dataset
 
-  @property
-  def output_types(self):
-    return self._dtypes
+    super(AvroDataset, self).__init__(
+        self._dataset._variant_tensor, [dtype], [shape])  # pylint: disable=protected-access
diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 25c62ec34..07bdc0b36 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -28,6 +28,7 @@ cc_library(
     name = "dataset_ops",
     srcs = [
         "kernels/dataset_ops.h",
+        "kernels/stream.h",
    ],
    copts = tf_io_copts(),
    includes = [
diff --git a/tensorflow_io/core/kernels/stream.h b/tensorflow_io/core/kernels/stream.h
new file mode 100644
index 000000000..e812babf4
--- /dev/null
+++ b/tensorflow_io/core/kernels/stream.h
@@ -0,0 +1,71 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+
+namespace tensorflow {
+namespace data {
+
+// Note: This SizedRandomAccessFile should only live within the Compute()
+// of the kernel, as the buffer may be released outside of it.
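+// When optional_memory_size is zero the file is opened through Env as a
+// regular RandomAccessFile; otherwise Read() is served straight from the
+// caller-provided memory buffer and the filename is never opened.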
+class SizedRandomAccessFile : public tensorflow::RandomAccessFile {
+ public:
+  SizedRandomAccessFile(Env* env, const string& filename, const void* optional_memory_buff, const size_t optional_memory_size)
+    : file_(nullptr)
+    , size_(optional_memory_size)
+    , buff_((const char *)(optional_memory_buff))
+    , size_status_(Status::OK()) {
+    if (size_ == 0) {
+      size_status_ = env->GetFileSize(filename, &size_);
+      if (size_status_.ok()) {
+        size_status_ = env->NewRandomAccessFile(filename, &file_);
+      }
+    }
+  }
+
+  virtual ~SizedRandomAccessFile() {}
+  Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override {
+    if (file_.get() != nullptr) {
+      return file_.get()->Read(offset, n, result, scratch);
+    }
+    size_t bytes_to_read = 0;
+    if (offset < size_) {
+      bytes_to_read = (offset + n < size_) ? n : (size_ - offset);
+    }
+    if (bytes_to_read > 0) {
+      memcpy(scratch, &buff_[offset], bytes_to_read);
+    }
+    *result = StringPiece(scratch, bytes_to_read);
+    if (bytes_to_read < n) {
+      return errors::OutOfRange("EOF reached");
+    }
+    return Status::OK();
+  }
+  Status GetFileSize(uint64* size) {
+    if (size_status_.ok()) {
+      *size = size_;
+    }
+    return size_status_;
+  }
+ private:
+  std::unique_ptr<tensorflow::RandomAccessFile> file_;
+  uint64 size_;
+  const char *buff_;
+  Status size_status_;
+};
+
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py
index bbd3b5d8b..40107037d 100644
--- a/tensorflow_io/core/python/ops/data_ops.py
+++ b/tensorflow_io/core/python/ops/data_ops.py
@@ -58,9 +58,8 @@ def _apply_fn(dataset):
 class BaseDataset(tf.compat.v2.data.Dataset):
   """A Base Dataset"""
 
-  def __init__(self, variant, batch, dtypes, shapes):
+  def __init__(self, variant, dtypes, shapes):
     """Create a Base Dataset."""
-    self._batch = 0 if batch is None else batch
     self._dtypes = dtypes
     self._shapes = shapes
     super(BaseDataset, self).__init__(variant)
@@ -93,4 +92,4 @@ def __init__(self, fn, data_input, batch, dtypes, shapes):
         self._data_input,
         self._batch,
         output_types=self._dtypes,
-        output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)
+        output_shapes=self._shapes), self._dtypes, self._shapes)
diff --git a/tests/test_avro.py b/tests/test_avro.py
index 6024b11ca..d134ee26c 100644
--- a/tests/test_avro.py
+++ b/tests/test_avro.py
@@ -19,74 +19,47 @@ from __future__ import print_function
 
 import os
+import numpy as np
+import pytest
 
 import tensorflow as tf
 tf.compat.v1.disable_eager_execution()
-
-from tensorflow import dtypes            # pylint: disable=wrong-import-position
-from tensorflow import errors            # pylint: disable=wrong-import-position
-from tensorflow import test              # pylint: disable=wrong-import-position
-from tensorflow.compat.v1 import data    # pylint: disable=wrong-import-position
-
 import tensorflow_io.avro as avro_io  # pylint: disable=wrong-import-position
 
-class AvroDatasetTest(test.TestCase):
-  """AvroDatasetTest"""
-
-  def test_avro_dataset(self):
-    """Test case for AvroDataset."""
-    # The test.bin was created from avro/lang/c++/examples/datafile.cc.
-    filename = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)),
-        "test_avro", "test.bin")
-    filename = "file://" + filename
-
-    schema_filename = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)),
-        "test_avro", "cpx.json")
-    with open(schema_filename, 'r') as f:
-      schema = f.read()
-
-    columns = ['im', 're']
-    output_types = (dtypes.float64, dtypes.float64)
-    num_repeats = 2
+def test_avro_dataset():
+  """Test case for AvroDataset."""
+  # The test.bin was created from avro/lang/c++/examples/datafile.cc.
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_avro", "test.bin")
+  filename = "file://" + filename
 
-    dataset = avro_io.AvroDataset(
-        [filename], columns, schema, output_types).repeat(num_repeats)
-    iterator = data.make_initializable_iterator(dataset)
-    init_op = iterator.initializer
-    get_next = iterator.get_next()
+  schema_filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_avro", "cpx.json")
+  with open(schema_filename, 'r') as f:
+    schema = f.read()
 
-    with self.test_session() as sess:
-      sess.run(init_op)
-      for _ in range(num_repeats):
-        for i in range(100):
-          (im, re) = (i + 100, i * 100)
-          vv = sess.run(get_next)
-          self.assertAllClose((im, re), vv)
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
+  num_repeats = 2
+  dataset = tf.compat.v2.data.Dataset.zip(
+      (
+          avro_io.AvroDataset(filename, schema, "im", dtype=tf.float64),
+          avro_io.AvroDataset(filename, schema, "re", dtype=tf.float64)
+      )).repeat(num_repeats).apply(tf.data.experimental.unbatch())
 
-    dataset = avro_io.AvroDataset(
-        [filename, filename], columns, schema, output_types, batch=3)
-    iterator = data.make_initializable_iterator(dataset)
-    init_op = iterator.initializer
-    get_next = iterator.get_next()
+  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
+  init_op = iterator.initializer
+  get_next = iterator.get_next()
 
-    with self.test_session() as sess:
-      sess.run(init_op)
-      for ii in range(0, 198, 3):
-        i = ii % 100
-        (im, re) = (
-            [i + 100, ((i + 1) % 100) + 100, ((i + 2) % 100) + 100],
-            [i * 100, ((i + 1) % 100) * 100, ((i + 2) % 100) * 100])
-        vv = sess.run(get_next)
-        self.assertAllClose((im, re), vv)
-      (im, re) = ([198, 199], [9800, 9900])
-      vv = sess.run(get_next)
-      self.assertAllClose((im, re), vv)
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
+  with tf.compat.v1.Session() as sess:
+    sess.run(init_op)
+    for _ in range(num_repeats):
+      for i in range(100):
+        (im, re) = (i + 100, i * 100)
+        vv = sess.run(get_next)
+        assert np.allclose((im, re), vv)
+    with pytest.raises(tf.errors.OutOfRangeError):
+      sess.run(get_next)
 
 if __name__ == "__main__":
-  test.main()
+  tf.test.main()
diff --git a/tests/test_avro_eager.py b/tests/test_avro_eager.py
new file mode 100644
index 000000000..172697239
--- /dev/null
+++ b/tests/test_avro_eager.py
@@ -0,0 +1,69 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for AvroDataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
+  tf.compat.v1.enable_eager_execution()
+import tensorflow_io.avro as avro_io  # pylint: disable=wrong-import-position
+
+def test_avro():
+  """Test case for list_avro_columns, read_avro and AvroDataset."""
+  # The test.bin was created from avro/lang/c++/examples/datafile.cc.
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_avro", "test.bin")
+  filename = "file://" + filename
+
+  schema_filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_avro", "cpx.json")
+  with open(schema_filename, 'r') as f:
+    schema = f.read()
+
+  specs = avro_io.list_avro_columns(filename, schema)
+  assert specs["im"].dtype == tf.float64
+  assert specs["re"].dtype == tf.float64
+
+  v0 = avro_io.read_avro(filename, schema, specs["im"])
+  v1 = avro_io.read_avro(filename, schema, specs["re"])
+  for i in range(100):
+    (im, re) = (i + 100, i * 100)
+    assert v0[i].numpy() == im
+    assert v1[i].numpy() == re
+
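+  # Vary capacity so the same file is carved into different numbers of
+  # splits; the stream of values must be identical in every case.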
+  for capacity in [10, 20, 50, 100, 1000, 2000]:
+    dataset = tf.compat.v2.data.Dataset.zip(
+        (
+            avro_io.AvroDataset(filename, schema, "im", capacity=capacity),
+            avro_io.AvroDataset(filename, schema, "re", capacity=capacity)
+        )
+    ).apply(tf.data.experimental.unbatch())
+    i = 0
+    for vv in dataset:
+      v0, v1 = vv
+      (im, re) = (i + 100, i * 100)
+      assert v0.numpy() == im
+      assert v1.numpy() == re
+      i += 1
+
+if __name__ == "__main__":
+  tf.test.main()