From ed7e1817ff0725df9b7bc0031c14bfc3e647d5a3 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Fri, 23 Aug 2019 19:34:18 +0000
Subject: [PATCH] Add tfio.IOTensor.from_hdf5 support

HDF5 is a widely used file format. It stores data in named `dataset`s,
each of which is a block of array data with a shape. HDF5 is not exactly
columnar, as different `dataset`s within one file could have different,
unrelated shapes. From that standpoint an HDF5 file is more like a
storage container for a collection of tensors, where each `dataset`
represents one `tensor`. HDF5 does allow slicing and indexing; in fact,
slicing and indexing in HDF5 are much more powerful than in many other
formats.

This PR adds tfio.IOTensor.from_hdf5, which treats an HDF5 file as a
collection of BaseIOTensor values that can be further sliced and
indexed. Note the `collection` here is essentially just a dictionary
mapping each key to a BaseIOTensor value. This is different from the
columnar IOTensor cases such as Parquet or Avro.
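For example, with this change in place an HDF5 file can be read as
follows (a usage sketch mirroring the updated tests below; `tdset.h5`
is the test file containing datasets `/dset1` and `/dset2`):

    import tensorflow_io as tfio

    # Open the HDF5 file as a collection of tensors, keyed by dataset name.
    hdf5 = tfio.IOTensor.from_hdf5("tdset.h5")
    print(hdf5.keys)             # e.g. ['/dset1', '/dset2']

    # Each key resolves to a BaseIOTensor that supports indexing/slicing.
    dset1 = hdf5('/dset1')
    print(dset1.dtype)           # tf.int32
    print(dset1.shape)           # [10, 20]
    first_row = dset1[0]         # read one row through __getitem__

    # A dataset can also be converted into a tf.data.Dataset.
    dataset = dset1.to_dataset()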
Signed-off-by: Yong Tang
---
 .../core/python/ops/hdf5_io_tensor_ops.py  |  51 ++++++
 tensorflow_io/core/python/ops/io_tensor.py |  18 +++
 .../core/python/ops/io_tensor_ops.py       |  42 +++++
 tensorflow_io/hdf5/kernels/hdf5_kernels.cc | 153 ++++++++++++++++++
 tensorflow_io/hdf5/ops/hdf5_ops.cc         |  40 +++++
 tensorflow_io/hdf5/python/ops/hdf5_ops.py  |   8 +
 tests/test_hdf5_eager.py                   |  29 ++--
 7 files changed, 325 insertions(+), 16 deletions(-)
 create mode 100644 tensorflow_io/core/python/ops/hdf5_io_tensor_ops.py

diff --git a/tensorflow_io/core/python/ops/hdf5_io_tensor_ops.py b/tensorflow_io/core/python/ops/hdf5_io_tensor_ops.py
new file mode 100644
index 000000000..4c5b9ed34
--- /dev/null
+++ b/tensorflow_io/core/python/ops/hdf5_io_tensor_ops.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""HDF5IOTensor"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import uuid
+
+import tensorflow as tf
+from tensorflow_io.core.python.ops import io_tensor_ops
+from tensorflow_io.core.python.ops import core_ops
+
+class HDF5IOTensor(io_tensor_ops._CollectionIOTensor): # pylint: disable=protected-access
+  """HDF5IOTensor"""
+
+  #=============================================================================
+  # Constructor (private)
+  #=============================================================================
+  def __init__(self,
+               filename,
+               internal=False):
+    with tf.name_scope("HDF5IOTensor") as scope:
+      resource, columns = core_ops.hdf5_indexable_init(
+          filename,
+          container=scope,
+          shared_name="%s/%s" % (filename, uuid.uuid4().hex))
+      columns = [column.decode() for column in columns.numpy().tolist()]
+      spec = []
+      for column in columns:
+        shape, dtype = core_ops.hdf5_indexable_spec(resource, column)
+        shape = tf.TensorShape(shape)
+        dtype = tf.as_dtype(dtype.numpy())
+        spec.append(tf.TensorSpec(shape, dtype, column))
+      spec = tuple(spec)
+      super(HDF5IOTensor, self).__init__(
+          spec, columns,
+          resource, core_ops.hdf5_indexable_get_item,
+          internal=internal)
diff --git a/tensorflow_io/core/python/ops/io_tensor.py b/tensorflow_io/core/python/ops/io_tensor.py
index 7884becc3..e452ed40c 100644
--- a/tensorflow_io/core/python/ops/io_tensor.py
+++ b/tensorflow_io/core/python/ops/io_tensor.py
@@ -21,6 +21,7 @@
 from tensorflow_io.core.python.ops import io_tensor_ops
 from tensorflow_io.core.python.ops import audio_io_tensor_ops
 from tensorflow_io.core.python.ops import json_io_tensor_ops
+from tensorflow_io.core.python.ops import hdf5_io_tensor_ops
 from tensorflow_io.core.python.ops import kafka_io_tensor_ops
 from tensorflow_io.core.python.ops import lmdb_io_tensor_ops
 from tensorflow_io.core.python.ops import prometheus_io_tensor_ops
@@ -346,3 +347,20 @@ def from_lmdb(cls,
     """
     with tf.name_scope(kwargs.get("name", "IOFromLMDB")):
       return lmdb_io_tensor_ops.LMDBIOTensor(filename, internal=True)
+
+  @classmethod
+  def from_hdf5(cls,
+                filename,
+                **kwargs):
+    """Creates an `IOTensor` from an HDF5 file.
+
+    Args:
+      filename: A string, the filename of an HDF5 file.
+      name: A name prefix for the IOTensor (optional).
+
+    Returns:
+      An `IOTensor`.
+
+    """
+    with tf.name_scope(kwargs.get("name", "IOFromHDF5")):
+      return hdf5_io_tensor_ops.HDF5IOTensor(filename, internal=True)
diff --git a/tensorflow_io/core/python/ops/io_tensor_ops.py b/tensorflow_io/core/python/ops/io_tensor_ops.py
index 9a178d0b7..c4e9b450d 100644
--- a/tensorflow_io/core/python/ops/io_tensor_ops.py
+++ b/tensorflow_io/core/python/ops/io_tensor_ops.py
@@ -316,6 +316,48 @@
         spec, self._resource, self._function,
         component=column, internal=True)
 
+class _CollectionIOTensor(_IOTensor):
+  """_CollectionIOTensor
+
+  `CollectionIOTensor` is different from `TableIOTensor` in that each
+  component could have a different shape. While additional table-wide
+  operations are planned to be supported for `TableIOTensor`, so that
+  the same operations could be applied to every column, there is no
+  plan to support the same in `CollectionIOTensor`. In other words,
+  `CollectionIOTensor` is only a dictionary with values consisting
+  of `BaseIOTensor`.
+ """ + + def __init__(self, + spec, + keys, + resource, + function, + internal=False): + self._keys = keys + self._resource = resource + self._function = function + super(_CollectionIOTensor, self).__init__( + spec, keys, internal=internal) + + #============================================================================= + # Accessors + #============================================================================= + + @property + def keys(self): + """The names of columns""" + return self._keys + + def __call__(self, key): + """Return a BaseIOTensor with key named `key`""" + key_index = self.keys.index( + next(e for e in self.keys if e == key)) + spec = tf.nest.flatten(self.spec)[key_index] + return BaseIOTensor( + spec, self._resource, self._function, + component=key, internal=True) + class _SeriesIOTensor(_IOTensor): """_SeriesIOTensor""" diff --git a/tensorflow_io/hdf5/kernels/hdf5_kernels.cc b/tensorflow_io/hdf5/kernels/hdf5_kernels.cc index ac8ab4cba..5919d1437 100644 --- a/tensorflow_io/hdf5/kernels/hdf5_kernels.cc +++ b/tensorflow_io/hdf5/kernels/hdf5_kernels.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/stream.h" #include #include @@ -320,5 +322,156 @@ REGISTER_KERNEL_BUILDER(Name("ReadHDF5").Device(DEVICE_CPU), } // namespace + + +class HDF5Indexable : public IOIndexableInterface { + public: + HDF5Indexable(Env* env) + : env_(env) {} + + ~HDF5Indexable() {} + Status Init(const std::vector& input, const std::vector& metadata, const void* memory_data, const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 filename is not supported"); + } + const string& filename = input[0]; + file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size)); + TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_)); + + file_image_.reset(new HDF5FileImage(env_, filename, "")); + H5::H5File *file = file_image_->GetFile(); + if (file == nullptr) { + return errors::InvalidArgument("unable to open hdf5 file: ", filename); + } + + H5O_info_t info; + file->getObjinfo(info); + HDF5Iterate data(info.addr); + herr_t err = H5Literate(file->getId(), H5_INDEX_NAME, H5_ITER_NATIVE, NULL, HDF5Iterate::Iterate, (void *)&data); + for (size_t i = 0; i < data.datasets_.size(); i++) { + columns_.emplace_back(data.datasets_[i]); + columns_index_[data.datasets_[i]] = i; + } + + for (size_t i = 0; i < columns_.size(); i++) { + ::tensorflow::DataType dtype; + string dataset = columns_[i]; + H5::DataSet data_set = file->openDataSet(dataset); + + H5::DataSpace data_space = data_set.getSpace(); + int rank = data_space.getSimpleExtentNdims(); + absl::InlinedVector dims(rank); + data_space.getSimpleExtentDims(dims.data()); + + H5::DataType data_type = data_set.getDataType(); + hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND); + if (H5Tequal(native_type, H5T_NATIVE_INT)) { + dtype = DT_INT32; + } else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) { + dtype = DT_UINT32; + } else if (H5Tequal(native_type, H5T_NATIVE_LONG)) { + dtype = DT_INT64; + } else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) { + dtype = DT_FLOAT; + } else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) { + dtype = DT_DOUBLE; + } else { + return errors::InvalidArgument("unsupported data type: ", native_type); + } + dtypes_.emplace_back(dtype); + 
+      absl::InlinedVector<int64, 4> shape_dims(rank);
+      for (int r = 0; r < rank; r++) {
+        shape_dims[r] = dims[r];
+      }
+      shapes_.emplace_back(TensorShape(shape_dims));
+    }
+    return Status::OK();
+  }
+  Status Component(Tensor* component) override {
+    *component = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
+    for (size_t i = 0; i < columns_.size(); i++) {
+      component->flat<string>()(i) = columns_[i];
+    }
+    return Status::OK();
+  }
+  Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype) override {
+    const int64 column_index = columns_index_[component.scalar<string>()()];
+    *shape = shapes_[column_index];
+    *dtype = dtypes_[column_index];
+    return Status::OK();
+  }
+
+  Status GetItem(const int64 start, const int64 stop, const int64 step, const Tensor& component, Tensor* tensor) override {
+    if (step != 1) {
+      return errors::InvalidArgument("step ", step, " is not supported");
+    }
+    const string& column = component.scalar<string>()();
+
+    H5::H5File *file = file_image_->GetFile();
+    try {
+      H5::DataSet data_set = file->openDataSet(column);
+      H5::DataSpace data_space = data_set.getSpace();
+
+      int rank = data_space.getSimpleExtentNdims();
+      absl::InlinedVector<hsize_t, 4> dims(rank);
+      data_space.getSimpleExtentDims(dims.data());
+
+      if (start > dims[0] || stop > dims[0]) {
+        return errors::InvalidArgument("dataset ", column, " selection is out of boundary");
+      }
+      // Select the hyperslab [start, stop) along the first dimension,
+      // keeping all remaining dimensions whole.
+      absl::InlinedVector<hsize_t, 4> dims_start(dims.size(), 0);
+      dims_start[0] = start;
+      dims[0] = stop - start;
+
+      H5::DataSpace memory_space(dims.size(), dims.data());
+
+      data_space.selectHyperslab(H5S_SELECT_SET, dims.data(), dims_start.data());
+
+      H5::DataType data_type = data_set.getDataType();
+      hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND);
+      if (H5Tequal(native_type, H5T_NATIVE_INT)) {
+        data_set.read(tensor->flat<int32>().data(), H5::PredType::NATIVE_INT, memory_space, data_space);
+      } else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) {
+        data_set.read(tensor->flat<uint32>().data(), H5::PredType::NATIVE_UINT32, memory_space, data_space);
+      } else if (H5Tequal(native_type, H5T_NATIVE_LONG)) {
+        data_set.read(tensor->flat<int64>().data(), H5::PredType::NATIVE_LONG, memory_space, data_space);
+      } else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) {
+        data_set.read(tensor->flat<float>().data(), H5::PredType::NATIVE_FLOAT, memory_space, data_space);
+      } else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) {
+        data_set.read(tensor->flat<double>().data(), H5::PredType::NATIVE_DOUBLE, memory_space, data_space);
+      } else {
+        return errors::Unimplemented("data type not supported yet: ", data_set.getTypeClass());
+      }
+    } catch (H5::FileIException& e) {
+      return errors::InvalidArgument("unable to open dataset: ", e.getCDetailMsg());
+    }
+
+    return Status::OK();
+  }
+
+  string DebugString() const override {
+    mutex_lock l(mu_);
+    return strings::StrCat("HDF5Indexable");
+  }
+ private:
+  mutable mutex mu_;
+  Env* env_ GUARDED_BY(mu_);
+  std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
+  uint64 file_size_ GUARDED_BY(mu_);
+  std::unique_ptr<HDF5FileImage> file_image_;
+
+  std::vector<DataType> dtypes_;
+  std::vector<TensorShape> shapes_;
+  std::vector<string> columns_;
+  std::unordered_map<string, int64> columns_index_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("HDF5IndexableInit").Device(DEVICE_CPU),
+                        IOInterfaceInitOp<HDF5Indexable>);
+REGISTER_KERNEL_BUILDER(Name("HDF5IndexableSpec").Device(DEVICE_CPU),
+                        IOInterfaceSpecOp<HDF5Indexable>);
+REGISTER_KERNEL_BUILDER(Name("HDF5IndexableGetItem").Device(DEVICE_CPU),
+                        IOIndexableGetItemOp<HDF5Indexable>);
 }  // namespace data
 }  // namespace tensorflow
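For reference, the [start, stop) row selection that `GetItem` performs
through `selectHyperslab` matches plain HDF5 slicing semantics. A minimal
sketch of the equivalent read using h5py (not part of this change)
against the `tdset.h5` test file:

    import h5py

    with h5py.File("tdset.h5", "r") as f:
        dset = f["/dset1"]    # shape (10, 20)
        # Rows [2, 5) with step 1; the kernel expresses this as a
        # hyperslab with offset [2, 0] and count [3, 20].
        rows = dset[2:5]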
diff --git a/tensorflow_io/hdf5/ops/hdf5_ops.cc b/tensorflow_io/hdf5/ops/hdf5_ops.cc
index c3229697d..9745b207c 100644
--- a/tensorflow_io/hdf5/ops/hdf5_ops.cc
+++ b/tensorflow_io/hdf5/ops/hdf5_ops.cc
@@ -19,6 +19,46 @@ limitations under the License.
 
 namespace tensorflow {
 
+REGISTER_OP("HDF5IndexableInit")
+    .Input("input: string")
+    .Output("output: resource")
+    .Output("component: string")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
+      return Status::OK();
+    });
+
+REGISTER_OP("HDF5IndexableSpec")
+    .Input("input: resource")
+    .Input("component: string")
+    .Output("shape: int64")
+    .Output("dtype: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({}));
+      return Status::OK();
+    });
+
+REGISTER_OP("HDF5IndexableGetItem")
+    .Input("input: resource")
+    .Input("start: int64")
+    .Input("stop: int64")
+    .Input("step: int64")
+    .Input("component: string")
+    .Output("output: dtype")
+    .Attr("shape: shape")
+    .Attr("dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      PartialTensorShape shape;
+      TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+      shape_inference::ShapeHandle entry;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
+      c->set_output(0, entry);
+      return Status::OK();
+    });
+
 REGISTER_OP("ListHDF5Datasets")
     .Input("filename: string")
     .Input("memory: string")
diff --git a/tensorflow_io/hdf5/python/ops/hdf5_ops.py b/tensorflow_io/hdf5/python/ops/hdf5_ops.py
index 6c20777af..7d7fd4750 100644
--- a/tensorflow_io/hdf5/python/ops/hdf5_ops.py
+++ b/tensorflow_io/hdf5/python/ops/hdf5_ops.py
@@ -17,10 +17,18 @@
 from __future__ import division
 from __future__ import print_function
 
+import warnings
+
 import tensorflow as tf
 from tensorflow_io.core.python.ops import core_ops
 from tensorflow_io.core.python.ops import data_ops
 
+warnings.warn(
+    "tensorflow_io.hdf5.HDF5Dataset is deprecated. "
+    "Please use tfio.IOTensor.from_hdf5 to read "
+    "HDF5 files into TensorFlow.",
+    DeprecationWarning)
+
 def list_hdf5_datasets(filename, **kwargs):
   """list_hdf5_datasets"""
   if not tf.executing_eagerly():
diff --git a/tests/test_hdf5_eager.py b/tests/test_hdf5_eager.py
index 50c74bd45..ef40801cb 100644
--- a/tests/test_hdf5_eager.py
+++ b/tests/test_hdf5_eager.py
@@ -24,7 +24,7 @@
 import tensorflow as tf
 if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
   tf.compat.v1.enable_eager_execution()
-import tensorflow_io.hdf5 as hdf5_io  # pylint: disable=wrong-import-position
+import tensorflow_io as tfio  # pylint: disable=wrong-import-position
 
 def test_hdf5_list_dataset():
   """test_hdf5_list_dataset"""
@@ -35,11 +35,11 @@
   # Without file:// file will be opened directly, otherwise
   # file will be opened in memory.
   for filename in [filename, "file://" + filename]:
-    specs = hdf5_io.list_hdf5_datasets(filename)
-    assert specs['/group1/dset1'].dtype == tf.int32
-    assert specs['/group1/dset1'].shape == tf.TensorShape([1, 1])
-    assert specs['/group1/group3/dset2'].dtype == tf.int32
-    assert specs['/group1/group3/dset2'].shape == tf.TensorShape([1, 1])
+    hdf5 = tfio.IOTensor.from_hdf5(filename)
+    assert hdf5('/group1/dset1').dtype == tf.int32
+    assert hdf5('/group1/dset1').shape == [1, 1]
+    assert hdf5('/group1/group3/dset2').dtype == tf.int32
+    assert hdf5('/group1/group3/dset2').shape == [1, 1]
 
 def test_hdf5_read_dataset():
   """test_hdf5_list_dataset"""
@@ -48,21 +48,18 @@
       "test_hdf5", "tdset.h5")
 
   for filename in [filename, "file://" + filename]:
-    specs = hdf5_io.list_hdf5_datasets(filename)
-    assert specs['/dset1'].dtype == tf.int32
-    assert specs['/dset1'].shape == tf.TensorShape([10, 20])
-    assert specs['/dset2'].dtype == tf.float64
-    assert specs['/dset2'].shape == tf.TensorShape([30, 20])
+    hdf5 = tfio.IOTensor.from_hdf5(filename)
+    assert hdf5('/dset1').dtype == tf.int32
+    assert hdf5('/dset1').shape == [10, 20]
+    assert hdf5('/dset2').dtype == tf.float64
+    assert hdf5('/dset2').shape == [30, 20]
 
-    p1 = hdf5_io.read_hdf5(filename, specs['/dset1'])
-    assert p1.dtype == tf.int32
-    assert p1.shape == tf.TensorShape([10, 20])
+    p1 = hdf5('/dset1')
     for i in range(10):
       vv = list([np.asarray([v for v in range(i, i + 20)])])
       assert np.all(p1[i].numpy() == vv)
 
-    dataset = hdf5_io.HDF5Dataset(filename, '/dset1').apply(
-        tf.data.experimental.unbatch())
+    dataset = tfio.IOTensor.from_hdf5(filename)('/dset1').to_dataset()
     i = 0
     for p in dataset:
       vv = list([np.asarray([v for v in range(i, i + 20)])])