diff --git a/tensorflow_io/core/python/ops/data_ops.py b/tensorflow_io/core/python/ops/data_ops.py
index bbd3b5d8b..40107037d 100644
--- a/tensorflow_io/core/python/ops/data_ops.py
+++ b/tensorflow_io/core/python/ops/data_ops.py
@@ -58,9 +58,8 @@ def _apply_fn(dataset):
 class BaseDataset(tf.compat.v2.data.Dataset):
   """A Base Dataset"""
 
-  def __init__(self, variant, batch, dtypes, shapes):
+  def __init__(self, variant, dtypes, shapes):
     """Create a Base Dataset."""
-    self._batch = 0 if batch is None else batch
     self._dtypes = dtypes
     self._shapes = shapes
     super(BaseDataset, self).__init__(variant)
@@ -93,4 +92,4 @@ def __init__(self, fn, data_input, batch, dtypes, shapes):
         self._data_input,
         self._batch,
         output_types=self._dtypes,
-        output_shapes=self._shapes), self._batch, self._dtypes, self._shapes)
+        output_shapes=self._shapes), self._dtypes, self._shapes)
diff --git a/tensorflow_io/hdf5/BUILD b/tensorflow_io/hdf5/BUILD
index 337a675b0..1a6ba42e1 100644
--- a/tensorflow_io/hdf5/BUILD
+++ b/tensorflow_io/hdf5/BUILD
@@ -10,7 +10,7 @@ load(
 cc_library(
     name = "hdf5_ops",
     srcs = [
-        "kernels/hdf5_input.cc",
+        "kernels/hdf5_kernels.cc",
         "ops/hdf5_ops.cc",
     ],
     copts = tf_io_copts(),
diff --git a/tensorflow_io/hdf5/__init__.py b/tensorflow_io/hdf5/__init__.py
index 06da8a0a2..1fd7a0e69 100644
--- a/tensorflow_io/hdf5/__init__.py
+++ b/tensorflow_io/hdf5/__init__.py
@@ -15,6 +15,8 @@
 """HDF5 Dataset.
 
 @@HDF5Dataset
+@@list_hdf5_datasets
+@@read_hdf5
 """
 
 from __future__ import absolute_import
@@ -22,11 +24,15 @@ from __future__ import print_function
 
 from tensorflow_io.hdf5.python.ops.hdf5_ops import HDF5Dataset
+from tensorflow_io.hdf5.python.ops.hdf5_ops import list_hdf5_datasets
+from tensorflow_io.hdf5.python.ops.hdf5_ops import read_hdf5
 
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
     "HDF5Dataset",
+    "list_hdf5_datasets",
+    "read_hdf5",
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow_io/hdf5/kernels/hdf5_input.cc b/tensorflow_io/hdf5/kernels/hdf5_input.cc
deleted file mode 100644
index c1a8a5769..000000000
--- a/tensorflow_io/hdf5/kernels/hdf5_input.cc
+++ /dev/null
@@ -1,174 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "kernels/dataset_ops.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" -#include -#include -#include - -namespace tensorflow { -namespace data { - -class HDF5InputStream{ -public: - explicit HDF5InputStream(io::InputStreamInterface* s, const std::vector& columns) - : columns_(columns) - , input_stream_(nullptr) - , buffered_stream_(nullptr) - , file_(nullptr) { - input_stream_ = dynamic_cast(s); - if (input_stream_ == nullptr) { - buffered_stream_.reset(new SizedRandomAccessBufferedStream(s)); - input_stream_ = buffered_stream_.get(); - } - } - ~HDF5InputStream() { - H5Fclose(file_image_); - file_.reset(nullptr); - buffered_stream_.reset(nullptr); - } - Status Open() { - uint64 size = 0; - TF_RETURN_IF_ERROR(input_stream_->GetFileSize(&size)); - buffer_.resize(size); - StringPiece result; - TF_RETURN_IF_ERROR(input_stream_->Read(0, size, &result, &buffer_[0])); - if (result.size() != size) { - return errors::InvalidArgument("unable to read enough data from file"); - } - file_image_ = H5LTopen_file_image((void *)buffer_.data(), size, H5LT_FILE_IMAGE_DONT_COPY | H5LT_FILE_IMAGE_DONT_RELEASE); - file_.reset(new H5::H5File()); - file_.get()->setId(file_image_); - // TODO: replace boilerplate - for (size_t i = 0; i < columns_.size(); i++) { - try { - H5::DataSet dataset = file_->openDataSet(H5std_string(columns_[i])); - H5::DataSpace dataspace = dataset.getSpace(); - int rank = dataspace.getSimpleExtentNdims(); - absl::InlinedVector dims(rank); - dataspace.getSimpleExtentDims(dims.data()); - dataset_.emplace_back(dataset); - dataspace_.emplace_back(dataspace); - dims_.emplace_back(dims); - - // Make sure first dimension remains the same - if (i == 0) { - count_ = dims[0]; - } else if (count_ != dims[0]) { - // Maybe we should fill in blanks? 
- return errors::InvalidArgument("dataset ", columns_[i], " has uneven count ", dims[0], " with others ", count_); - } - } catch(H5::FileIException e){ - return errors::InvalidArgument("unable to open dataset ", columns_[i], ": ", e.getCDetailMsg()); - } - } - return Status::OK(); - } - Status ReadRecord(IteratorContext* ctx, int64 record_to_read, int64* record_read, std::vector* out_tensors) { - if (index_ + record_to_read > count_) { - record_to_read = count_ - index_; - } - out_tensors->clear(); - if (record_to_read > 0) { - for (size_t i = 0; i < columns_.size(); i++) { - absl::InlinedVector dims = dims_[i]; - dims[0] = record_to_read; - H5::DataSpace memoryspace(dims.size(), dims.data()); - absl::InlinedVector start(dims_[i].size(), 0); - start[0] = index_; - dataspace_[i].selectHyperslab(H5S_SELECT_SET, dims.data(), start.data()); - - absl::InlinedVector shape_dims(dims_[i].size()); - for (size_t ii = 0; ii < dims_[i].size(); ii++) { - shape_dims[ii] = dims_[i][ii]; - } - shape_dims[0] = record_to_read; - TensorShape shape(shape_dims); - - - H5::DataType data_type = dataset_[i].getDataType(); - hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND); - if (H5Tequal(native_type, H5T_NATIVE_INT)) { - Tensor tensor(ctx->allocator({}), DT_INT32, shape); - dataset_[i].read(tensor.flat().data(), H5::PredType::NATIVE_INT, memoryspace, dataspace_[i]); - out_tensors->emplace_back(std::move(tensor)); - } else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) { - Tensor tensor(ctx->allocator({}), DT_UINT32, shape); - dataset_[i].read(tensor.flat().data(), H5::PredType::NATIVE_UINT32, memoryspace, dataspace_[i]); - out_tensors->emplace_back(std::move(tensor)); - }else if (H5Tequal(native_type, H5T_NATIVE_LONG)) { - Tensor tensor(ctx->allocator({}), DT_INT64, shape); - dataset_[i].read(tensor.flat().data(), H5::PredType::NATIVE_LONG, memoryspace, dataspace_[i]); - out_tensors->emplace_back(std::move(tensor)); - } else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) { - Tensor tensor(ctx->allocator({}), DT_FLOAT, shape); - dataset_[i].read(tensor.flat().data(), H5::PredType::NATIVE_FLOAT, memoryspace, dataspace_[i]); - out_tensors->emplace_back(std::move(tensor)); - } else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) { - Tensor tensor(ctx->allocator({}), DT_DOUBLE, shape); - dataset_[i].read(tensor.flat().data(), H5::PredType::NATIVE_DOUBLE, memoryspace, dataspace_[i]); - out_tensors->emplace_back(std::move(tensor)); - } else { - return errors::Unimplemented("data type not supported yet: ", dataset_[i].getTypeClass()); - } - } - *record_read = record_to_read; - index_ += record_to_read; - } - return Status::OK(); - } -private: - std::vector columns_; - SizedRandomAccessInputStreamInterface* input_stream_; - std::unique_ptr buffered_stream_; - string buffer_; - std::unique_ptr file_; - hid_t file_image_; - std::vector dataset_; - std::vector dataspace_; - std::vector> dims_; - int64 count_ = -1; - int64 index_ = 0; -}; - -class HDF5Input: public FileInput { - public: - Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { - if (state.get() == nullptr) { - state.reset(new HDF5InputStream(s, columns())); - TF_RETURN_IF_ERROR(state.get()->Open()); - } - return state.get()->ReadRecord(ctx, record_to_read, record_read, out_tensors); - } - Status FromStream(io::InputStreamInterface* s) override { - return Status::OK(); - } - void EncodeAttributes(VariantTensorData* data) 
const override { - } - bool DecodeAttributes(const VariantTensorData& data) override { - return true; - } - protected: -}; - -REGISTER_UNARY_VARIANT_DECODE_FUNCTION(HDF5Input, "tensorflow::data::HDF5Input"); - -REGISTER_KERNEL_BUILDER(Name("HDF5Input").Device(DEVICE_CPU), - FileInputOp); -REGISTER_KERNEL_BUILDER(Name("HDF5Dataset").Device(DEVICE_CPU), - FileInputDatasetOp); -} // namespace data -} // namespace tensorflow diff --git a/tensorflow_io/hdf5/kernels/hdf5_kernels.cc b/tensorflow_io/hdf5/kernels/hdf5_kernels.cc new file mode 100644 index 000000000..ac8ab4cba --- /dev/null +++ b/tensorflow_io/hdf5/kernels/hdf5_kernels.cc @@ -0,0 +1,324 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" + +#include +#include +#include + +namespace tensorflow { +namespace data { +namespace { + +class HDF5FileImage { + public: + HDF5FileImage(Env* env, const string& filename, const string& optional_memory) + : filename_(filename) + , optional_memory_(optional_memory) + , file_(nullptr) { + if (optional_memory.size() != 0) { + file_image_ = H5LTopen_file_image((void *)optional_memory_.data(), optional_memory_.size(), H5LT_FILE_IMAGE_DONT_COPY | H5LT_FILE_IMAGE_DONT_RELEASE); + file_.reset(new H5::H5File()); + file_.get()->setId(file_image_); + } else if (filename.find("://") == string::npos) { + file_.reset(new H5::H5File(filename, H5F_ACC_RDONLY)); + } else { + uint64 size = 0; + Status status = env->GetFileSize(filename, &size); + if (status.ok()) { + std::unique_ptr file; + status = env->NewRandomAccessFile(filename, &file); + if (status.ok()) { + StringPiece result; + buffer_memory_.resize(size); + status = file->Read(0, size, &result, &buffer_memory_[0]); + if (status.ok()) { + file_image_ = H5LTopen_file_image((void *)buffer_memory_.data(), buffer_memory_.size(), H5LT_FILE_IMAGE_DONT_COPY | H5LT_FILE_IMAGE_DONT_RELEASE); + file_.reset(new H5::H5File()); + file_.get()->setId(file_image_); + } + } + } + } + } + + virtual ~HDF5FileImage() { + if (file_image_ != 0) { + H5Fclose(file_image_); + } + file_.reset(nullptr); + } + + H5::H5File *GetFile() const { + return file_.get(); + } + + + private: + string filename_; + const string& optional_memory_; + string buffer_memory_; + std::unique_ptr file_; + hid_t file_image_ = 0; +}; + +class HDF5Iterate { +public: + HDF5Iterate(haddr_t root) + : parent_(root) { + groups_[root] = ""; + } + ~HDF5Iterate() {} + + static herr_t Iterate(hid_t loc_id, const char *name, const H5L_info_t *info, void *operator_data) { + HDF5Iterate *p = (HDF5Iterate *)operator_data; + + H5O_info_t iteminfo; + herr_t err = H5Oget_info_by_name (loc_id, name, &iteminfo, H5P_DEFAULT); + + switch (iteminfo.type) { + case H5O_TYPE_GROUP: + if (p->groups_.find(iteminfo.addr) == p->groups_.end()) { + haddr_t parent = p->parent_; + p->groups_[iteminfo.addr] = p->groups_[parent] + "/" + name; + p->parent_ = 
iteminfo.addr; + err = H5Literate_by_name(loc_id, name, H5_INDEX_NAME, H5_ITER_NATIVE, NULL, HDF5Iterate::Iterate, operator_data, H5P_DEFAULT); + p->parent_ = parent; + } + break; + case H5O_TYPE_DATASET: { + string dataset = p->groups_[p->parent_] + "/" + name; + p->datasets_.emplace_back(dataset); + } + break; + case H5O_TYPE_NAMED_DATATYPE: + break; + default: + break; + } + return err; + } + + std::vector datasets_; + std::unordered_map groups_; + haddr_t parent_; +}; + +class ListHDF5DatasetsOp : public OpKernel { + public: + explicit ListHDF5DatasetsOp(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filename_tensor = context->input(0); + const string filename = filename_tensor.scalar()(); + + const Tensor& memory_tensor = context->input(1); + const string& memory = memory_tensor.scalar()(); + + HDF5FileImage file_image(env_, filename, memory); + H5::H5File *file = file_image.GetFile(); + OP_REQUIRES(context, file != nullptr, errors::InvalidArgument("unable to open hdf5 file: ", filename)); + + H5O_info_t info; + file->getObjinfo(info); + + HDF5Iterate data(info.addr); + + herr_t err = H5Literate (file->getId(), H5_INDEX_NAME, H5_ITER_NATIVE, NULL, HDF5Iterate::Iterate, (void *)&data); + + std::vector datasets; + std::vector dtypes; + std::vector> shapes; + datasets.reserve(data.datasets_.size()); + dtypes.reserve(data.datasets_.size()); + shapes.reserve(data.datasets_.size()); + int maxrank = 0; + for (size_t i = 0; i < data.datasets_.size(); i++) { + string dataset = data.datasets_[i]; + string dtype = ""; + H5::DataSet data_set = file->openDataSet(dataset); + + H5::DataSpace data_space = data_set.getSpace(); + int rank = data_space.getSimpleExtentNdims(); + absl::InlinedVector dims(rank); + data_space.getSimpleExtentDims(dims.data()); + + maxrank = rank < maxrank ? 
maxrank : rank; + + H5::DataType data_type = data_set.getDataType(); + hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND); + if (H5Tequal(native_type, H5T_NATIVE_INT)) { + dtype = "int32"; + } else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) { + dtype = "uint32"; + } else if (H5Tequal(native_type, H5T_NATIVE_LONG)) { + dtype = "int64"; + } else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) { + dtype = "float"; + } else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) { + dtype = "double"; + } else { + continue; + } + datasets.emplace_back(dataset); + dtypes.emplace_back(dtype); + shapes.emplace_back(dims); + } + + TensorShape output_shape = filename_tensor.shape(); + output_shape.AddDim(datasets.size()); + + Tensor* datasets_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &datasets_tensor)); + Tensor* dtypes_tensor; + OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor)); + + for (size_t i = 0; i < datasets.size(); i++) { + datasets_tensor->flat()(i) = datasets[i]; + dtypes_tensor->flat()(i) = dtypes[i]; + } + + output_shape.AddDim(maxrank); + + Tensor* shapes_tensor; + OP_REQUIRES_OK(context, context->allocate_output(2, output_shape, &shapes_tensor)); + for (size_t i = 0; i < shapes.size(); i++) { + for (size_t j = 0; j < shapes[i].size(); j++) { + shapes_tensor->flat()(i * maxrank + j) = shapes[i][j]; + } + for (size_t j = shapes[i].size(); j < maxrank; j++) { + shapes_tensor->flat()(i * maxrank + j) = -1; + } + } + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +class ReadHDF5Op : public OpKernel { + public: + explicit ReadHDF5Op(OpKernelConstruction* context) : OpKernel(context) { + env_ = context->env(); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filename_tensor = context->input(0); + const string& filename = filename_tensor.scalar()(); + + const Tensor& dataset_tensor = context->input(1); + const string& dataset = dataset_tensor.scalar()(); + + const Tensor& memory_tensor = context->input(2); + const string& memory = memory_tensor.scalar()(); + + const Tensor& start_tensor = context->input(3); + + const Tensor& stop_tensor = context->input(4); + + HDF5FileImage file_image(env_, filename, memory); + H5::H5File *file = file_image.GetFile(); + OP_REQUIRES(context, file != nullptr, errors::InvalidArgument("unable to open hdf5 file: ", filename)); + try { + H5::DataSet data_set = file->openDataSet(dataset); + + H5::DataSpace data_space = data_set.getSpace(); + int rank = data_space.getSimpleExtentNdims(); + absl::InlinedVector dims(rank); + data_space.getSimpleExtentDims(dims.data()); + + std::vector start(dims.size(), 0); + std::vector stop(dims.size(), -1); + for (size_t i = 0; i < start_tensor.NumElements(); i++) { + start[i] = start_tensor.flat()(i); + } + for (size_t i = 0; i < stop_tensor.NumElements(); i++) { + stop[i] = stop_tensor.flat()(i); + } + for (size_t i = 0; i < stop.size(); i++) { + if (stop[i] < 0) { + stop[i] = dims[i]; + } + } + + // Find the border of the dims start + absl::InlinedVector dims_start(dims.size(), 0); + for (int64 i = 0; i < dims_start.size(); i++) { + dims_start[i] = (start[i] < dims[i]) ? (start[i]) : (dims[i]); + } + // Find the border of the dims final + absl::InlinedVector dims_final(dims); + for (int64 i = 0; i < dims_final.size(); i++) { + dims_final[i] = (stop[i] < dims[i]) ? 
(stop[i]) : (dims[i]); + } + // Find the area of the dims = [start...final] + absl::InlinedVector dims_shape(dims.size()); + for (int64 i = 0; i < dims_shape.size(); i++) { + dims[i] = (dims_final[i] > dims_start[i]) ? (dims_final[i] - dims_start[i]) : 0; + dims_shape[i] = dims[i]; + } + + TensorShape output_shape(dims_shape); + + Tensor* output_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); + + // Return with zero elements + for (int64 i = 0; i < dims_shape.size(); i++) { + if (dims_shape[0] == 0) { + return; + } + } + + H5::DataSpace memory_space(dims.size(), dims.data()); + + data_space.selectHyperslab(H5S_SELECT_SET, dims.data(), dims_start.data()); + + H5::DataType data_type = data_set.getDataType(); + hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND); + if (H5Tequal(native_type, H5T_NATIVE_INT)) { + data_set.read(output_tensor->flat().data(), H5::PredType::NATIVE_INT, memory_space, data_space); + } else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) { + data_set.read(output_tensor->flat().data(), H5::PredType::NATIVE_UINT32, memory_space, data_space); + } else if (H5Tequal(native_type, H5T_NATIVE_LONG)) { + data_set.read(output_tensor->flat().data(), H5::PredType::NATIVE_LONG, memory_space, data_space); + } else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) { + data_set.read(output_tensor->flat().data(), H5::PredType::NATIVE_FLOAT, memory_space, data_space); + } else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) { + data_set.read(output_tensor->flat().data(), H5::PredType::NATIVE_DOUBLE, memory_space, data_space); + } else { + OP_REQUIRES(context, false, errors::Unimplemented("data type not supported yet: ", data_set.getTypeClass())); + } + } catch(H5::FileIException e){ + OP_REQUIRES(context, false, errors::InvalidArgument("unable to open dataset", e.getCDetailMsg())); + } + } + private: + mutex mu_; + Env* env_ GUARDED_BY(mu_); +}; + +REGISTER_KERNEL_BUILDER(Name("ListHDF5Datasets").Device(DEVICE_CPU), + ListHDF5DatasetsOp); +REGISTER_KERNEL_BUILDER(Name("ReadHDF5").Device(DEVICE_CPU), + ReadHDF5Op); + + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/hdf5/ops/hdf5_ops.cc b/tensorflow_io/hdf5/ops/hdf5_ops.cc index 69def2fd3..c3229697d 100644 --- a/tensorflow_io/hdf5/ops/hdf5_ops.cc +++ b/tensorflow_io/hdf5/ops/hdf5_ops.cc @@ -19,28 +19,30 @@ limitations under the License. 
 namespace tensorflow {
 
-REGISTER_OP("HDF5Input")
-    .Input("source: string")
-    .Output("handle: variant")
-    .Attr("filters: list(string) = []")
-    .Attr("columns: list(string) = []")
-    .Attr("schema: string = ''")
+REGISTER_OP("ListHDF5Datasets")
+    .Input("filename: string")
+    .Input("memory: string")
+    .Output("datasets: string")
+    .Output("dtypes: string")
+    .Output("shapes: int64")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
+      c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
       return Status::OK();
     });
 
-REGISTER_OP("HDF5Dataset")
-    .Input("input: T")
-    .Input("batch: int64")
-    .Output("handle: variant")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .Attr("T: {string, variant} = DT_VARIANT")
-    .SetIsStateful()
+REGISTER_OP("ReadHDF5")
+    .Input("filename: string")
+    .Input("dataset: string")
+    .Input("memory: string")
+    .Input("start: int64")
+    .Input("stop: int64")
+    .Attr("dtype: type")
+    .Output("output: dtype")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->MakeShape({}));
-      return Status::OK();
-    });
+      c->set_output(0, c->UnknownShape());
+      return Status::OK();
+    });
 
 }  // namespace tensorflow
diff --git a/tensorflow_io/hdf5/python/ops/hdf5_ops.py b/tensorflow_io/hdf5/python/ops/hdf5_ops.py
index baba4b464..6c20777af 100644
--- a/tensorflow_io/hdf5/python/ops/hdf5_ops.py
+++ b/tensorflow_io/hdf5/python/ops/hdf5_ops.py
@@ -18,48 +18,69 @@ from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow.compat.v1 import data
-from tensorflow_io.core.python.ops import core_ops as hdf5_ops
+from tensorflow_io.core.python.ops import core_ops
+from tensorflow_io.core.python.ops import data_ops
 
-class HDF5Dataset(data.Dataset):
+def list_hdf5_datasets(filename, **kwargs):
+  """list_hdf5_datasets"""
+  if not tf.executing_eagerly():
+    raise NotImplementedError("list_hdf5_datasets only supports eager mode")
+  memory = kwargs.get("memory", "")
+  datasets, dtypes, shapes = core_ops.list_hdf5_datasets(
+      filename, memory=memory)
+  entries = zip(tf.unstack(datasets), tf.unstack(dtypes), tf.unstack(shapes))
+  entries = [
+      (dataset, dtype, tf.boolean_mask(
+          shape, tf.math.greater_equal(shape, 0))) for (
+              dataset, dtype, shape) in entries]
+  return dict([(dataset.numpy().decode(), tf.TensorSpec(
+      shape.numpy(), dtype.numpy().decode(), dataset.numpy().decode())) for (
+          dataset, dtype, shape) in entries])
+
+def read_hdf5(filename, dataset, **kwargs):
+  """read_hdf5"""
+  memory = kwargs.get("memory", "")
+  start = kwargs.get("start", 0)
+  stop = kwargs.get("stop", None)
+  if stop is None and dataset.shape[0] is not None:
+    stop = dataset.shape[0] - start
+  if stop is None:
+    stop = -1
+  return core_ops.read_hdf5(
+      filename, dataset.name, memory=memory,
+      start=start, stop=stop, dtype=dataset.dtype)
+
+class HDF5Dataset(data_ops.BaseDataset):
   """A HDF5 Dataset that reads the hdf5 file."""
 
-  def __init__(self, filenames, columns, dtypes=None, shapes=None, batch=None):
+  def __init__(self, filename, dataset, **kwargs):
     """Create a `HDF5Dataset`.
 
     Args:
-      filenames: A 0-D or 1-D `tf.string` tensor containing one or more
-        filenames.
-      columns: A 0-D or 1-D `tf.int32` tensor containing the columns to extract.
-      dtypes: A tuple of `tf.DType` objects representing the types of the
-        columns returned.
+      filename: A string of the hdf5 filename.
+      dataset: A string of the dataset name.
""" - self._data_input = hdf5_ops.hdf5_input( - filenames, ["none", "gz"], columns=columns) - self._columns = columns - self._dtypes = dtypes - self._shapes = shapes - self._batch = 0 if batch is None else batch - super(HDF5Dataset, self).__init__() - - def _inputs(self): - return [] - - def _as_variant_tensor(self): - return hdf5_ops.hdf5_dataset( - self._data_input, - self._batch, - output_types=self.output_types, - output_shapes=self.output_shapes) - - @property - def output_classes(self): - return tuple([tf.Tensor for _ in self._columns]) - - @property - def output_shapes(self): - return tuple([tf.TensorShape([]) for _ in self._shapes]) + if not tf.executing_eagerly(): + start = kwargs.get("start") + stop = kwargs.get("stop") + dtype = kwargs.get("dtype") + shape = kwargs.get("shape") + else: + datasets = list_hdf5_datasets(filename) + start = 0 + stop = datasets[dataset].shape[0] + dtype = datasets[dataset].dtype + shape = tf.TensorShape( + [None if i == 0 else e for i, e in enumerate( + datasets[dataset].shape.as_list())]) - @property - def output_types(self): - return tuple([dtype for dtype in self._dtypes]) + # capacity is the rough count for each chunk in dataset + capacity = kwargs.get("capacity", 65536) + entry_start = list(range(start, stop, capacity)) + entry_stop = entry_start[1:] + [stop] + self._dataset = data_ops.BaseDataset.from_tensor_slices( + (tf.constant(entry_start, tf.int64), tf.constant(entry_stop, tf.int64)) + ).map(lambda start, stop: core_ops.read_hdf5( + filename, dataset, memory="", start=start, stop=stop, dtype=dtype)) + super(HDF5Dataset, self).__init__( + self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access diff --git a/tests/test_hdf5.py b/tests/test_hdf5.py index 7f68e2beb..0ea3dbcb1 100644 --- a/tests/test_hdf5.py +++ b/tests/test_hdf5.py @@ -40,10 +40,12 @@ def test_hdf5_invalid_dataset(self): "test_hdf5", "tdset.h5") filename = "file://" + filename dataset = hdf5_io.HDF5Dataset( - [filename], - ['/invalid', '/invalid2'], - [dtypes.int32, dtypes.int32], - [(1, 20), (1, 30)]) + filename, + '/invalid', + dtype=dtypes.int32, + shape=tf.TensorShape([1, 20]), + start=0, + stop=10) iterator = data.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() @@ -51,7 +53,7 @@ def test_hdf5_invalid_dataset(self): with self.test_session() as sess: sess.run(init_op) with self.assertRaisesRegexp( - errors.InvalidArgumentError, "unable to open dataset /invalid"): + errors.InvalidArgumentError, "unable to open dataset"): sess.run(get_next) def test_hdf5_dataset_int32(self): @@ -60,12 +62,13 @@ def test_hdf5_dataset_int32(self): os.path.dirname(os.path.abspath(__file__)), "test_hdf5", "tdset.h5") filename = "file://" + filename - columns = ['/dset1'] - output_types = [dtypes.int32] - output_shapes = [(1, 20)] + column = '/dset1' + dtype = dtypes.int32 + shape = tf.TensorShape([None, 20]) dataset = hdf5_io.HDF5Dataset( - [filename], columns, output_types, output_shapes) + filename, column, start=0, stop=10, dtype=dtype, shape=shape).apply( + tf.data.experimental.unbatch()) iterator = data.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() @@ -74,7 +77,7 @@ def test_hdf5_dataset_int32(self): for i in range(10): v0 = list([np.asarray([v for v in range(i, i + 20)])]) vv = sess.run(get_next) - self.assertAllEqual(v0, vv) + self.assertAllEqual(v0, [vv]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -89,12 +92,13 @@ def 
         os.path.dirname(os.path.abspath(__file__)),
         "test_hdf5", "compressed_h5.h5")
     filename = "file://" + filename
-    columns = ['/dset1']
-    output_types = [dtypes.int32]
-    output_shapes = [(1, 20)]
+    column = '/dset1'
+    dtype = dtypes.int32
+    shape = tf.TensorShape([None, 20])
 
     dataset = hdf5_io.HDF5Dataset(
-        [filename], columns, output_types, output_shapes)
+        filename, column, start=0, stop=10, dtype=dtype, shape=shape).apply(
+            tf.data.experimental.unbatch())
     iterator = data.make_initializable_iterator(dataset)
     init_op = iterator.initializer
     get_next = iterator.get_next()
@@ -103,7 +107,7 @@ def test_hdf5_dataset_int32_zlib(self):
      for i in range(10):
         v0 = list([np.asarray([v for v in range(i, i + 20)])])
         vv = sess.run(get_next)
-        self.assertAllEqual(v0, vv)
+        self.assertAllEqual(v0, [vv])
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
@@ -114,46 +118,23 @@ def test_hdf5_dataset(self):
         os.path.dirname(os.path.abspath(__file__)),
         "test_hdf5", "tdset.h5")
     filename = "file://" + filename
-    columns = ['/dset2']
-    output_types = [dtypes.float64]
-    output_shapes = [(1, 20)]
+    column = '/dset2'
+    dtype = dtypes.float64
+    shape = tf.TensorShape([None, 20])
 
     dataset = hdf5_io.HDF5Dataset(
-        [filename], columns, output_types, output_shapes, batch=1)
+        filename, column, start=0, stop=30, dtype=dtype, shape=shape).apply(
+            tf.data.experimental.unbatch())
     iterator = data.make_initializable_iterator(dataset)
     init_op = iterator.initializer
     get_next = iterator.get_next()
     with self.test_session() as sess:
       sess.run(init_op)
       for i in range(30):
-        v0 = list(
-            [np.asarray([[i + 1e-04 * v for v in range(20)]],
-                        dtype=np.float64)])
+        v0 = np.asarray([[i + 1e-04 * v for v in range(20)]],
+                        dtype=np.float64)
         vv = sess.run(get_next)
-        self.assertAllEqual(v0, vv)
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
-
-  def test_hdf5_dataset_binary(self):
-    """Test case for HDF5Dataset."""
-    filename = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)),
-        "test_hdf5", "tbinary.h5")
-    filename = "file://" + filename
-    columns = ['integer', 'float', 'double']
-    output_types = [dtypes.int32, dtypes.float32, dtypes.float64]
-    output_shapes = [(1), (1), (1)]
-
-    dataset = hdf5_io.HDF5Dataset(
-        [filename], columns, output_types, output_shapes)
-    iterator = data.make_initializable_iterator(dataset)
-    init_op = iterator.initializer
-    get_next = iterator.get_next()
-    with self.test_session() as sess:
-      sess.run(init_op)
-      for i in range(1, 7):
-        vv = sess.run(get_next)
-        self.assertAllEqual((i, np.float32(i), np.float64(i)), vv)
+        self.assertAllEqual(v0, [vv])
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
diff --git a/tests/test_hdf5/h5ex_g_traverse.h5 b/tests/test_hdf5/h5ex_g_traverse.h5
new file mode 100644
index 000000000..d8267b14e
Binary files /dev/null and b/tests/test_hdf5/h5ex_g_traverse.h5 differ
diff --git a/tests/test_hdf5_eager.py b/tests/test_hdf5_eager.py
new file mode 100644
index 000000000..50c74bd45
--- /dev/null
+++ b/tests/test_hdf5_eager.py
@@ -0,0 +1,74 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for HDF5."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+
+import tensorflow as tf
+if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
+  tf.compat.v1.enable_eager_execution()
+import tensorflow_io.hdf5 as hdf5_io  # pylint: disable=wrong-import-position
+
+def test_hdf5_list_dataset():
+  """test_hdf5_list_dataset"""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_hdf5", "h5ex_g_traverse.h5")
+
+  # Without the file:// prefix the file is opened directly; otherwise
+  # it is read into memory first.
+  for filename in [filename, "file://" + filename]:
+    specs = hdf5_io.list_hdf5_datasets(filename)
+    assert specs['/group1/dset1'].dtype == tf.int32
+    assert specs['/group1/dset1'].shape == tf.TensorShape([1, 1])
+    assert specs['/group1/group3/dset2'].dtype == tf.int32
+    assert specs['/group1/group3/dset2'].shape == tf.TensorShape([1, 1])
+
+def test_hdf5_read_dataset():
+  """test_hdf5_read_dataset"""
+  filename = os.path.join(
+      os.path.dirname(os.path.abspath(__file__)),
+      "test_hdf5", "tdset.h5")
+
+  for filename in [filename, "file://" + filename]:
+    specs = hdf5_io.list_hdf5_datasets(filename)
+    assert specs['/dset1'].dtype == tf.int32
+    assert specs['/dset1'].shape == tf.TensorShape([10, 20])
+    assert specs['/dset2'].dtype == tf.float64
+    assert specs['/dset2'].shape == tf.TensorShape([30, 20])
+
+    p1 = hdf5_io.read_hdf5(filename, specs['/dset1'])
+    assert p1.dtype == tf.int32
+    assert p1.shape == tf.TensorShape([10, 20])
+    for i in range(10):
+      vv = list([np.asarray([v for v in range(i, i + 20)])])
+      assert np.all(p1[i].numpy() == vv)
+
+    dataset = hdf5_io.HDF5Dataset(filename, '/dset1').apply(
+        tf.data.experimental.unbatch())
+    i = 0
+    for p in dataset:
+      vv = list([np.asarray([v for v in range(i, i + 20)])])
+      assert np.all(p.numpy() == vv)
+      i += 1
+
+
+if __name__ == "__main__":
+  tf.test.main()
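
For reviewers, a minimal usage sketch of the eager-mode API this patch introduces (list_hdf5_datasets, read_hdf5, and the reworked HDF5Dataset), modeled on the behavior exercised in tests/test_hdf5_eager.py. It is not part of the patch; the file path and the '/dset1' dataset name are placeholders taken from the test data.

    import tensorflow as tf
    import tensorflow_io.hdf5 as hdf5_io

    # Assumes eager execution (TF 2.x, or tf.compat.v1.enable_eager_execution()).
    filename = "tdset.h5"  # placeholder path to an HDF5 file

    # Enumerate datasets: a dict mapping dataset names to tf.TensorSpec.
    specs = hdf5_io.list_hdf5_datasets(filename)

    # Read one dataset eagerly as a single tensor.
    tensor = hdf5_io.read_hdf5(filename, specs['/dset1'])

    # Or stream it in chunks through tf.data; each element is a batch of rows,
    # so unbatch() yields individual rows.
    dataset = hdf5_io.HDF5Dataset(filename, '/dset1').apply(
        tf.data.experimental.unbatch())
    for row in dataset:
      print(row.numpy())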