Skip to content

Commit ff56c89

Browse files
authored
Add tfio.IOTensor.from_hdf5 support (#441)
HDF5 is a widely used file format. It normally stores data in named `dataset`s, each of which is an array with its own shape. It is not exactly columnar, as different `dataset`s in an HDF5 file can have different shapes unrelated to each other. From that standpoint it is more like storage for a collection of tensors (where each `dataset` represents one `tensor`). HDF5 does allow slicing and indexing. In fact, slicing and indexing in HDF5 are much more powerful than in many other formats. This PR adds tfio.IOTensor.from_hdf5. It treats an HDF5 file as a collection of BaseIOTensor objects which can be further used for slicing and indexing. Note that the `collection` here is essentially just a dictionary whose keys map to BaseIOTensor values. This is different from the columnar IOTensor case, such as Parquet or Avro. Signed-off-by: Yong Tang <[email protected]>
1 parent 5ec2680 commit ff56c89

File tree

7 files changed

+325
-16
lines changed

7 files changed

+325
-16
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""HDF5IOTensor"""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import uuid
21+
22+
import tensorflow as tf
23+
from tensorflow_io.core.python.ops import io_tensor_ops
24+
from tensorflow_io.core.python.ops import core_ops
25+
26+
class HDF5IOTensor(io_tensor_ops._CollectionIOTensor): # pylint: disable=protected-access
  """Collection of IOTensors, one per dataset reported for an HDF5 file.

  The constructor asks the `hdf5_indexable_init` op for a resource handle
  and the list of dataset names, then queries `hdf5_indexable_spec` once
  per name to build a `tf.TensorSpec` for each dataset.
  """

  #=============================================================================
  # Constructor (private)
  #=============================================================================
  def __init__(self,
               filename,
               internal=False):
    """Builds the per-dataset specs and initializes the collection.

    Args:
      filename: The HDF5 file name passed to `hdf5_indexable_init`.
      internal: Forwarded unchanged to the parent constructor.
    """
    with tf.name_scope("HDF5IOTensor") as scope:
      # A random suffix keeps the shared resource name unique per instance.
      resource, components = core_ops.hdf5_indexable_init(
          filename,
          container=scope,
          shared_name="%s/%s" % (filename, uuid.uuid4().hex))
      # NOTE: `.numpy()` is called on op outputs here and below, so this
      # presumably only works under eager execution.
      names = [entry.decode() for entry in components.numpy().tolist()]

      def _spec_of(name):
        """Returns the `tf.TensorSpec` reported by the resource for `name`."""
        shape, dtype = core_ops.hdf5_indexable_spec(resource, name)
        return tf.TensorSpec(
            tf.TensorShape(shape), tf.as_dtype(dtype.numpy()), name)

      spec = tuple(_spec_of(name) for name in names)
      super(HDF5IOTensor, self).__init__(
          spec, names,
          resource, core_ops.hdf5_indexable_get_item,
          internal=internal)

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tensorflow_io.core.python.ops import io_tensor_ops
2222
from tensorflow_io.core.python.ops import audio_io_tensor_ops
2323
from tensorflow_io.core.python.ops import json_io_tensor_ops
24+
from tensorflow_io.core.python.ops import hdf5_io_tensor_ops
2425
from tensorflow_io.core.python.ops import kafka_io_tensor_ops
2526
from tensorflow_io.core.python.ops import lmdb_io_tensor_ops
2627
from tensorflow_io.core.python.ops import prometheus_io_tensor_ops
@@ -346,3 +347,20 @@ def from_lmdb(cls,
346347
"""
347348
with tf.name_scope(kwargs.get("name", "IOFromLMDB")):
348349
return lmdb_io_tensor_ops.LMDBIOTensor(filename, internal=True)
350+
351+
@classmethod
def from_hdf5(cls,
              filename,
              **kwargs):
  """Creates an `IOTensor` from an hdf5 file.

  Args:
    filename: A string, the filename of an hdf5 file.
    name: A name prefix for the IOTensor (optional). When omitted, the
      name scope defaults to "IOFromHDF5".

  Returns:
    An `HDF5IOTensor` constructed for `filename`.

  """
  scope_name = kwargs.get("name", "IOFromHDF5")
  with tf.name_scope(scope_name):
    io_tensor = hdf5_io_tensor_ops.HDF5IOTensor(filename, internal=True)
  return io_tensor

tensorflow_io/core/python/ops/io_tensor_ops.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,48 @@ def __call__(self, column):
316316
spec, self._resource, self._function,
317317
component=column, internal=True)
318318

319+
class _CollectionIOTensor(_IOTensor):
  """_CollectionIOTensor

  `CollectionIOTensor` is different from `TableIOTensor` in that each
  component could have different shapes. While additional table-wide
  operations are planned to be supported for `TableIOTensor` so that
  the same operations could be applied to every column, there is no plan
  to support the same in `CollectionIOTensor`. In other words,
  `CollectionIOTensor` is only a dictionary with values consisting
  of `BaseIOTensor`.
  """

  def __init__(self,
               spec,
               keys,
               resource,
               function,
               internal=False):
    """Stores the lookup state and initializes the base `_IOTensor`.

    Args:
      spec: Specs of the components; flattened with `tf.nest.flatten` and
        indexed by position in `keys` when a component is requested.
      keys: Sequence of component names, in the same order as `spec`.
      resource: Resource handle handed to every `BaseIOTensor` created.
      function: Item-access function handed to every `BaseIOTensor` created.
      internal: Forwarded unchanged to the base constructor.
    """
    self._keys = keys
    self._resource = resource
    self._function = function
    super(_CollectionIOTensor, self).__init__(
        spec, keys, internal=internal)

  #=============================================================================
  # Accessors
  #=============================================================================

  @property
  def keys(self):
    """The names of columns"""
    return self._keys

  def __call__(self, key):
    """Return a BaseIOTensor with key named `key`

    Args:
      key: Name of the component to return; must be one of `self.keys`.

    Returns:
      A `BaseIOTensor` bound to the component named `key`.

    Raises:
      KeyError: If `key` is not one of `self.keys`.
    """
    # An explicit membership test gives callers a clear KeyError instead of
    # a bare StopIteration escaping from an exhausted generator.
    if key not in self.keys:
      raise KeyError(
          "key %r not found in collection, available keys: %r" % (
              key, self.keys))
    key_index = self.keys.index(key)
    spec = tf.nest.flatten(self.spec)[key_index]
    return BaseIOTensor(
        spec, self._resource, self._function,
        component=key, internal=True)
360+
319361
class _SeriesIOTensor(_IOTensor):
320362
"""_SeriesIOTensor"""
321363

tensorflow_io/hdf5/kernels/hdf5_kernels.cc

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "tensorflow/core/framework/op_kernel.h"
17+
#include "tensorflow_io/core/kernels/io_interface.h"
18+
#include "tensorflow_io/core/kernels/stream.h"
1719

1820
#include <hdf5.h>
1921
#include <hdf5_hl.h>
@@ -320,5 +322,156 @@ REGISTER_KERNEL_BUILDER(Name("ReadHDF5").Device(DEVICE_CPU),
320322

321323

322324
} // namespace
325+
326+
327+
class HDF5Indexable : public IOIndexableInterface {
328+
public:
329+
HDF5Indexable(Env* env)
330+
: env_(env) {}
331+
332+
~HDF5Indexable() {}
333+
Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
334+
if (input.size() > 1) {
335+
return errors::InvalidArgument("more than 1 filename is not supported");
336+
}
337+
const string& filename = input[0];
338+
file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
339+
TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
340+
341+
file_image_.reset(new HDF5FileImage(env_, filename, ""));
342+
H5::H5File *file = file_image_->GetFile();
343+
if (file == nullptr) {
344+
return errors::InvalidArgument("unable to open hdf5 file: ", filename);
345+
}
346+
347+
H5O_info_t info;
348+
file->getObjinfo(info);
349+
HDF5Iterate data(info.addr);
350+
herr_t err = H5Literate(file->getId(), H5_INDEX_NAME, H5_ITER_NATIVE, NULL, HDF5Iterate::Iterate, (void *)&data);
351+
for (size_t i = 0; i < data.datasets_.size(); i++) {
352+
columns_.emplace_back(data.datasets_[i]);
353+
columns_index_[data.datasets_[i]] = i;
354+
}
355+
356+
for (size_t i = 0; i < columns_.size(); i++) {
357+
::tensorflow::DataType dtype;
358+
string dataset = columns_[i];
359+
H5::DataSet data_set = file->openDataSet(dataset);
360+
361+
H5::DataSpace data_space = data_set.getSpace();
362+
int rank = data_space.getSimpleExtentNdims();
363+
absl::InlinedVector<hsize_t, 4> dims(rank);
364+
data_space.getSimpleExtentDims(dims.data());
365+
366+
H5::DataType data_type = data_set.getDataType();
367+
hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND);
368+
if (H5Tequal(native_type, H5T_NATIVE_INT)) {
369+
dtype = DT_INT32;
370+
} else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) {
371+
dtype = DT_UINT32;
372+
} else if (H5Tequal(native_type, H5T_NATIVE_LONG)) {
373+
dtype = DT_INT64;
374+
} else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) {
375+
dtype = DT_FLOAT;
376+
} else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) {
377+
dtype = DT_DOUBLE;
378+
} else {
379+
return errors::InvalidArgument("unsupported data type: ", native_type);
380+
}
381+
dtypes_.emplace_back(dtype);
382+
absl::InlinedVector<int64, 4> shape_dims(rank);
383+
for (int r = 0; r < rank; r++) {
384+
shape_dims[r] = dims[r];
385+
}
386+
shapes_.emplace_back(TensorShape(shape_dims));
387+
}
388+
return Status::OK();
389+
}
390+
Status Component(Tensor* component) override {
391+
*component = Tensor(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
392+
for (size_t i = 0; i < columns_.size(); i++) {
393+
component->flat<string>()(i) = columns_[i];
394+
}
395+
return Status::OK();
396+
}
397+
Status Spec(const Tensor& component, PartialTensorShape* shape, DataType* dtype) override {
398+
const int64 column_index = columns_index_[component.scalar<string>()()];
399+
*shape = shapes_[column_index];
400+
*dtype = dtypes_[column_index];
401+
return Status::OK();
402+
}
403+
404+
Status GetItem(const int64 start, const int64 stop, const int64 step, const Tensor& component, Tensor* tensor) override {
405+
if (step != 1) {
406+
return errors::InvalidArgument("step ", step, " is not supported");
407+
}
408+
const string& column = component.scalar<string>()();
409+
410+
H5::H5File *file = file_image_->GetFile();
411+
try {
412+
H5::DataSet data_set = file->openDataSet(column);
413+
H5::DataSpace data_space = data_set.getSpace();
414+
415+
int rank = data_space.getSimpleExtentNdims();
416+
absl::InlinedVector<hsize_t, 4> dims(rank);
417+
data_space.getSimpleExtentDims(dims.data());
418+
419+
if (start > dims[0] || stop > dims[0]) {
420+
return errors::InvalidArgument("dataset ", column, " selection is out of boundary");
421+
}
422+
// Find the border of the dims start and dims
423+
absl::InlinedVector<hsize_t, 4> dims_start(dims.size(), 0);
424+
dims_start[0] = start;
425+
dims[0] = stop - start;
426+
427+
H5::DataSpace memory_space(dims.size(), dims.data());
428+
429+
data_space.selectHyperslab(H5S_SELECT_SET, dims.data(), dims_start.data());
430+
431+
H5::DataType data_type = data_set.getDataType();
432+
hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND);
433+
if (H5Tequal(native_type, H5T_NATIVE_INT)) {
434+
data_set.read(tensor->flat<int32>().data(), H5::PredType::NATIVE_INT, memory_space, data_space);
435+
} else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) {
436+
data_set.read(tensor->flat<uint32>().data(), H5::PredType::NATIVE_UINT32, memory_space, data_space);
437+
} else if (H5Tequal(native_type, H5T_NATIVE_LONG)) {
438+
data_set.read(tensor->flat<int64>().data(), H5::PredType::NATIVE_LONG, memory_space, data_space);
439+
} else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) {
440+
data_set.read(tensor->flat<float>().data(), H5::PredType::NATIVE_FLOAT, memory_space, data_space);
441+
} else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) {
442+
data_set.read(tensor->flat<double>().data(), H5::PredType::NATIVE_DOUBLE, memory_space, data_space);
443+
} else {
444+
return errors::Unimplemented("data type not supported yet: ", data_set.getTypeClass());
445+
}
446+
} catch(H5::FileIException e){
447+
return errors::InvalidArgument("unable to open dataset", e.getCDetailMsg());
448+
}
449+
450+
return Status::OK();
451+
}
452+
453+
string DebugString() const override {
454+
mutex_lock l(mu_);
455+
return strings::StrCat("HDF5Indexable");
456+
}
457+
private:
458+
mutable mutex mu_;
459+
Env* env_ GUARDED_BY(mu_);
460+
std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
461+
uint64 file_size_ GUARDED_BY(mu_);
462+
std::unique_ptr<HDF5FileImage> file_image_;
463+
464+
std::vector<DataType> dtypes_;
465+
std::vector<TensorShape> shapes_;
466+
std::vector<string> columns_;
467+
std::unordered_map<string, int64> columns_index_;
468+
};
469+
470+
// CPU kernels backing the three HDF5Indexable* ops; each one instantiates a
// shared IO op template with HDF5Indexable.
REGISTER_KERNEL_BUILDER(
    Name("HDF5IndexableInit").Device(DEVICE_CPU),
    IOInterfaceInitOp<HDF5Indexable>);
REGISTER_KERNEL_BUILDER(
    Name("HDF5IndexableSpec").Device(DEVICE_CPU),
    IOInterfaceSpecOp<HDF5Indexable>);
REGISTER_KERNEL_BUILDER(
    Name("HDF5IndexableGetItem").Device(DEVICE_CPU),
    IOIndexableGetItemOp<HDF5Indexable>);
323476
} // namespace data
324477
} // namespace tensorflow

tensorflow_io/hdf5/ops/hdf5_ops.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,46 @@ limitations under the License.
1919

2020
namespace tensorflow {
2121

22+
// Creates the HDF5 indexable resource for a file and returns the resource
// handle together with the names of its components.
REGISTER_OP("HDF5IndexableInit")
  .Input("input: string")
  .Output("output: resource")
  .Output("component: string")
  .Attr("container: string = ''")
  .Attr("shared_name: string = ''")
  .SetShapeFn([](shape_inference::InferenceContext* c) {
    c->set_output(0, c->Scalar());
    // `component` carries one name per dataset (the Python wrapper iterates
    // over it as a list), so it is a vector whose length is only known once
    // the file has been opened - not a scalar.
    c->set_output(1, c->MakeShape({c->UnknownDim()}));
    return Status::OK();
   });
33+
// Reports the shape and dtype of one named component of the resource.
REGISTER_OP("HDF5IndexableSpec")
  .Input("input: resource")
  .Input("component: string")
  .Output("shape: int64")
  .Output("dtype: int64")
  .SetShapeFn([](shape_inference::InferenceContext* ctx) {
    // `shape` is a vector of statically unknown length; `dtype` is a scalar.
    ctx->set_output(0, ctx->Vector(ctx->UnknownDim()));
    ctx->set_output(1, ctx->Scalar());
    return Status::OK();
   });
43+
44+
REGISTER_OP("HDF5IndexableGetItem")
45+
.Input("input: resource")
46+
.Input("start: int64")
47+
.Input("stop: int64")
48+
.Input("step: int64")
49+
.Input("component: string")
50+
.Output("output: dtype")
51+
.Attr("shape: shape")
52+
.Attr("dtype: type")
53+
.SetShapeFn([](shape_inference::InferenceContext* c) {
54+
PartialTensorShape shape;
55+
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
56+
shape_inference::ShapeHandle entry;
57+
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry));
58+
c->set_output(0, entry);
59+
return Status::OK();
60+
});
61+
2262
REGISTER_OP("ListHDF5Datasets")
2363
.Input("filename: string")
2464
.Input("memory: string")

tensorflow_io/hdf5/python/ops/hdf5_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,18 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import warnings
21+
2022
import tensorflow as tf
2123
from tensorflow_io.core.python.ops import core_ops
2224
from tensorflow_io.core.python.ops import data_ops
2325

26+
# Emitted once at import time: this module is superseded by
# tfio.IOTensor.from_hdf5.
warnings.warn(
    "The tensorflow_io.hdf5.HDF5Dataset is deprecated. "
    "Please look for tfio.IOTensor.from_hdf5 for reading "
    "HDF5 files into tensorflow.",
    category=DeprecationWarning)
31+
2432
def list_hdf5_datasets(filename, **kwargs):
2533
"""list_hdf5_datasets"""
2634
if not tf.executing_eagerly():

0 commit comments

Comments
 (0)