Skip to content

Commit 49d7203

Browse files
committed
Add tfio.IOTensor.from_hdf5 support
Note: this PR depends on PR 437. HDF5 is a widely used file format. It normally stores data in named `dataset`s, each of which is a block of array data with a shape. It is not exactly columnar, as different `dataset`s in an HDF5 file can have different shapes that are unrelated to each other. From that standpoint it is more like storage for a collection of tensors (where each `dataset` represents one `tensor`). HDF5 does allow slicing and indexing. In fact, slicing and indexing in HDF5 are much more powerful than in many other formats. This PR adds tfio.IOTensor.from_hdf5 support to allow easy indexing and slicing in TF. It creates a _TableIOTensor, which assumes the HDF5 file to be tabular. This may not be the best fit, though, so some further enhancement might be needed to better position the HDF5 format within TF. Signed-off-by: Yong Tang <[email protected]>
1 parent 525d55f commit 49d7203

File tree

6 files changed

+292
-16
lines changed

6 files changed

+292
-16
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""HDF5IOTensor"""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import uuid
21+
22+
import tensorflow as tf
23+
from tensorflow_io.core.python.ops import io_tensor_ops
24+
from tensorflow_io.core.python.ops import core_ops
25+
26+
class HDF5IOTensor(io_tensor_ops._TableIOTensor): # pylint: disable=protected-access
27+
"""HDF5IOTensor"""
28+
29+
#=============================================================================
30+
# Constructor (private)
31+
#=============================================================================
32+
def __init__(self,
33+
filename,
34+
columns=None,
35+
internal=False):
36+
with tf.name_scope("HDF5IOTensor") as scope:
37+
metadata = []
38+
if columns is not None:
39+
metadata.extend(["column: "+column for column in columns])
40+
resource, dtypes, shapes, columns = core_ops.hdf5_indexable_init(
41+
filename, metadata=metadata,
42+
container=scope,
43+
shared_name="%s/%s" % (filename, uuid.uuid4().hex))
44+
self._filename = filename
45+
super(HDF5IOTensor, self).__init__(
46+
shapes, dtypes, columns, filename,
47+
resource, core_ops.hdf5_indexable_get_item,
48+
internal=internal)

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tensorflow_io.core.python.ops import io_tensor_ops
2222
from tensorflow_io.core.python.ops import audio_io_tensor_ops
2323
from tensorflow_io.core.python.ops import json_io_tensor_ops
24+
from tensorflow_io.core.python.ops import hdf5_io_tensor_ops
2425
from tensorflow_io.core.python.ops import kafka_io_tensor_ops
2526

2627
class IOTensor(io_tensor_ops._BaseIOTensor): # pylint: disable=protected-access
@@ -264,3 +265,20 @@ def from_kafka(cls,
264265
servers=kwargs.get("servers", None),
265266
configuration=kwargs.get("configuration", None),
266267
internal=True)
268+
269+
@classmethod
270+
def from_hdf5(cls,
271+
filename,
272+
**kwargs):
273+
"""Creates an `IOTensor` from an hdf5 file.
274+
275+
Args:
276+
filename: A string, the filename of an hdf5 file.
277+
name: A name prefix for the IOTensor (optional).
278+
279+
Returns:
280+
A `IOTensor`.
281+
282+
"""
283+
with tf.name_scope(kwargs.get("name", "IOFromHDF5")):
284+
return hdf5_io_tensor_ops.HDF5IOTensor(filename, internal=True)

tensorflow_io/hdf5/kernels/hdf5_kernels.cc

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "tensorflow/core/framework/op_kernel.h"
17+
#include "tensorflow_io/core/kernels/io_interface.h"
18+
#include "tensorflow_io/core/kernels/stream.h"
1719

1820
#include <hdf5.h>
1921
#include <hdf5_hl.h>
@@ -320,5 +322,168 @@ REGISTER_KERNEL_BUILDER(Name("ReadHDF5").Device(DEVICE_CPU),
320322

321323

322324
} // namespace
325+
326+
327+
class HDF5Indexable : public IOIndexableInterface {
328+
public:
329+
HDF5Indexable(Env* env)
330+
: env_(env) {}
331+
332+
~HDF5Indexable() {}
333+
Status Init(const std::vector<string>& input, const std::vector<string>& metadata, const void* memory_data, const int64 memory_size) override {
334+
if (input.size() > 1) {
335+
return errors::InvalidArgument("more than 1 filename is not supported");
336+
}
337+
const string& filename = input[0];
338+
file_.reset(new SizedRandomAccessFile(env_, filename, memory_data, memory_size));
339+
TF_RETURN_IF_ERROR(file_->GetFileSize(&file_size_));
340+
341+
file_image_.reset(new HDF5FileImage(env_, filename, ""));
342+
H5::H5File *file = file_image_->GetFile();
343+
if (file == nullptr) {
344+
return errors::InvalidArgument("unable to open hdf5 file: ", filename);
345+
}
346+
347+
for (size_t i = 0; i < metadata.size(); i++) {
348+
if (metadata[i].find_first_of("column: ") == 0) {
349+
columns_.emplace_back(metadata[i].substr(8));
350+
}
351+
}
352+
353+
if (columns_.size() == 0) {
354+
H5O_info_t info;
355+
file->getObjinfo(info);
356+
HDF5Iterate data(info.addr);
357+
herr_t err = H5Literate(file->getId(), H5_INDEX_NAME, H5_ITER_NATIVE, NULL, HDF5Iterate::Iterate, (void *)&data);
358+
for (size_t i = 0; i < data.datasets_.size(); i++) {
359+
columns_.emplace_back(data.datasets_[i]);
360+
}
361+
}
362+
363+
for (size_t i = 0; i < columns_.size(); i++) {
364+
::tensorflow::DataType dtype;
365+
string dataset = columns_[i];
366+
H5::DataSet data_set = file->openDataSet(dataset);
367+
368+
H5::DataSpace data_space = data_set.getSpace();
369+
int rank = data_space.getSimpleExtentNdims();
370+
absl::InlinedVector<hsize_t, 4> dims(rank);
371+
data_space.getSimpleExtentDims(dims.data());
372+
373+
H5::DataType data_type = data_set.getDataType();
374+
hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND);
375+
if (H5Tequal(native_type, H5T_NATIVE_INT)) {
376+
dtype = DT_INT32;
377+
} else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) {
378+
dtype = DT_UINT32;
379+
} else if (H5Tequal(native_type, H5T_NATIVE_LONG)) {
380+
dtype = DT_INT64;
381+
} else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) {
382+
dtype = DT_FLOAT;
383+
} else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) {
384+
dtype = DT_DOUBLE;
385+
} else {
386+
return errors::InvalidArgument("unsupported data type: ", native_type);
387+
}
388+
dtypes_.emplace_back(dtype);
389+
absl::InlinedVector<int64, 4> shape_dims(rank);
390+
for (int r = 0; r < rank; r++) {
391+
shape_dims[r] = dims[r];
392+
}
393+
shapes_.emplace_back(TensorShape(shape_dims));
394+
}
395+
return Status::OK();
396+
}
397+
Status Spec(std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) override {
398+
dtypes.clear();
399+
for (size_t i = 0; i < dtypes_.size(); i++) {
400+
dtypes.push_back(dtypes_[i]);
401+
}
402+
shapes.clear();
403+
for (size_t i = 0; i < shapes_.size(); i++) {
404+
shapes.push_back(shapes_[i]);
405+
}
406+
return Status::OK();
407+
}
408+
409+
Status Extra(std::vector<Tensor>* extra) override {
410+
// Expose columns
411+
Tensor columns(DT_STRING, TensorShape({static_cast<int64>(columns_.size())}));
412+
for (size_t i = 0; i < columns_.size(); i++) {
413+
columns.flat<string>()(i) = columns_[i];
414+
}
415+
extra->push_back(columns);
416+
return Status::OK();
417+
}
418+
419+
Status GetItem(const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
420+
if (step != 1) {
421+
return errors::InvalidArgument("step ", step, " is not supported");
422+
}
423+
H5::H5File *file = file_image_->GetFile();
424+
for (size_t i = 0; i < tensors.size(); i++) {
425+
try {
426+
H5::DataSet data_set = file->openDataSet(columns_[i]);
427+
H5::DataSpace data_space = data_set.getSpace();
428+
429+
int rank = data_space.getSimpleExtentNdims();
430+
absl::InlinedVector<hsize_t, 4> dims(rank);
431+
data_space.getSimpleExtentDims(dims.data());
432+
433+
if (start > dims[0] || stop > dims[0]) {
434+
return errors::InvalidArgument("dataset ", columns_[i], " selection is out of boundary");
435+
}
436+
// Find the border of the dims start and dims
437+
absl::InlinedVector<hsize_t, 4> dims_start(dims.size(), 0);
438+
dims_start[0] = start;
439+
dims[0] = stop - start;
440+
441+
H5::DataSpace memory_space(dims.size(), dims.data());
442+
443+
data_space.selectHyperslab(H5S_SELECT_SET, dims.data(), dims_start.data());
444+
445+
H5::DataType data_type = data_set.getDataType();
446+
hid_t native_type = H5Tget_native_type(data_type.getId(), H5T_DIR_ASCEND);
447+
if (H5Tequal(native_type, H5T_NATIVE_INT)) {
448+
data_set.read(tensors[i].flat<int32>().data(), H5::PredType::NATIVE_INT, memory_space, data_space);
449+
} else if (H5Tequal(native_type, H5T_NATIVE_UINT32)) {
450+
data_set.read(tensors[i].flat<uint32>().data(), H5::PredType::NATIVE_UINT32, memory_space, data_space);
451+
} else if (H5Tequal(native_type, H5T_NATIVE_LONG)) {
452+
data_set.read(tensors[i].flat<int64>().data(), H5::PredType::NATIVE_LONG, memory_space, data_space);
453+
} else if (H5Tequal(native_type, H5T_NATIVE_FLOAT)) {
454+
data_set.read(tensors[i].flat<float>().data(), H5::PredType::NATIVE_FLOAT, memory_space, data_space);
455+
} else if (H5Tequal(native_type, H5T_NATIVE_DOUBLE)) {
456+
data_set.read(tensors[i].flat<double>().data(), H5::PredType::NATIVE_DOUBLE, memory_space, data_space);
457+
} else {
458+
return errors::Unimplemented("data type not supported yet: ", data_set.getTypeClass());
459+
}
460+
} catch(H5::FileIException e){
461+
return errors::InvalidArgument("unable to open dataset", e.getCDetailMsg());
462+
}
463+
}
464+
465+
return Status::OK();
466+
}
467+
468+
string DebugString() const override {
469+
mutex_lock l(mu_);
470+
return strings::StrCat("HDF5Indexable");
471+
}
472+
private:
473+
mutable mutex mu_;
474+
Env* env_ GUARDED_BY(mu_);
475+
std::unique_ptr<SizedRandomAccessFile> file_ GUARDED_BY(mu_);
476+
uint64 file_size_ GUARDED_BY(mu_);
477+
std::unique_ptr<HDF5FileImage> file_image_;
478+
479+
std::vector<DataType> dtypes_;
480+
std::vector<TensorShape> shapes_;
481+
std::vector<string> columns_;
482+
};
483+
484+
REGISTER_KERNEL_BUILDER(Name("HDF5IndexableInit").Device(DEVICE_CPU),
485+
IOInterfaceInitOp<HDF5Indexable>);
486+
REGISTER_KERNEL_BUILDER(Name("HDF5IndexableGetItem").Device(DEVICE_CPU),
487+
IOIndexableGetItemOp<HDF5Indexable>);
323488
} // namespace data
324489
} // namespace tensorflow

tensorflow_io/hdf5/ops/hdf5_ops.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,46 @@ limitations under the License.
1919

2020
namespace tensorflow {
2121

22+
REGISTER_OP("HDF5IndexableInit")
23+
.Input("input: string")
24+
.Input("metadata: string")
25+
.Output("output: resource")
26+
.Output("dtypes: int64")
27+
.Output("shapes: int64")
28+
.Output("columns: string")
29+
.Attr("container: string = ''")
30+
.Attr("shared_name: string = ''")
31+
.SetIsStateful()
32+
.SetShapeFn([](shape_inference::InferenceContext* c) {
33+
c->set_output(0, c->Scalar());
34+
c->set_output(1, c->MakeShape({c->UnknownDim()}));
35+
c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
36+
c->set_output(3, c->MakeShape({c->UnknownDim()}));
37+
return Status::OK();
38+
});
39+
40+
REGISTER_OP("HDF5IndexableGetItem")
41+
.Input("input: resource")
42+
.Input("start: int64")
43+
.Input("stop: int64")
44+
.Input("step: int64")
45+
.Output("output: dtype")
46+
.Attr("dtype: list(type) >= 1")
47+
.Attr("shape: list(shape) >= 1")
48+
.SetShapeFn([](shape_inference::InferenceContext* c) {
49+
std::vector<PartialTensorShape> shape;
50+
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
51+
if (shape.size() != c->num_outputs()) {
52+
return errors::InvalidArgument("`shape` must be the same length as `types` (", shape.size(), " vs. ", c->num_outputs());
53+
}
54+
for (size_t i = 0; i < shape.size(); ++i) {
55+
shape_inference::ShapeHandle entry;
56+
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape[i], &entry));
57+
c->set_output(static_cast<int64>(i), entry);
58+
}
59+
return Status::OK();
60+
});
61+
2262
REGISTER_OP("ListHDF5Datasets")
2363
.Input("filename: string")
2464
.Input("memory: string")

tensorflow_io/hdf5/python/ops/hdf5_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,18 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import warnings
21+
2022
import tensorflow as tf
2123
from tensorflow_io.core.python.ops import core_ops
2224
from tensorflow_io.core.python.ops import data_ops
2325

26+
warnings.warn(
27+
"The tensorflow_io.hdf5.HDF5Dataset is "
28+
"deprecated. Please look for tfio.IOTensor.from_hdf5 "
29+
"for reading HDF5 files into tensorflow.",
30+
DeprecationWarning)
31+
2432
def list_hdf5_datasets(filename, **kwargs):
2533
"""list_hdf5_datasets"""
2634
if not tf.executing_eagerly():

tests/test_hdf5_eager.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import tensorflow as tf
2525
if not (hasattr(tf, "version") and tf.version.VERSION.startswith("2.")):
2626
tf.compat.v1.enable_eager_execution()
27-
import tensorflow_io.hdf5 as hdf5_io # pylint: disable=wrong-import-position
27+
import tensorflow_io as tfio # pylint: disable=wrong-import-position
2828

2929
def test_hdf5_list_dataset():
3030
"""test_hdf5_list_dataset"""
@@ -35,11 +35,11 @@ def test_hdf5_list_dataset():
3535
# Without file:// file will be opened directly, otherwise
3636
# file will be opened in memory.
3737
for filename in [filename, "file://" + filename]:
38-
specs = hdf5_io.list_hdf5_datasets(filename)
39-
assert specs['/group1/dset1'].dtype == tf.int32
40-
assert specs['/group1/dset1'].shape == tf.TensorShape([1, 1])
41-
assert specs['/group1/group3/dset2'].dtype == tf.int32
42-
assert specs['/group1/group3/dset2'].shape == tf.TensorShape([1, 1])
38+
hdf5 = tfio.IOTensor.from_hdf5(filename)
39+
assert hdf5.dtype('/group1/dset1') == tf.int32
40+
assert hdf5.shape('/group1/dset1') == [1, 1]
41+
assert hdf5.dtype('/group1/group3/dset2') == tf.int32
42+
assert hdf5.shape('/group1/group3/dset2') == [1, 1]
4343

4444
def test_hdf5_read_dataset():
4545
"""test_hdf5_list_dataset"""
@@ -48,21 +48,18 @@ def test_hdf5_read_dataset():
4848
"test_hdf5", "tdset.h5")
4949

5050
for filename in [filename, "file://" + filename]:
51-
specs = hdf5_io.list_hdf5_datasets(filename)
52-
assert specs['/dset1'].dtype == tf.int32
53-
assert specs['/dset1'].shape == tf.TensorShape([10, 20])
54-
assert specs['/dset2'].dtype == tf.float64
55-
assert specs['/dset2'].shape == tf.TensorShape([30, 20])
51+
hdf5 = tfio.IOTensor.from_hdf5(filename)
52+
assert hdf5.dtype('/dset1') == tf.int32
53+
assert hdf5.shape('/dset1') == [10, 20]
54+
assert hdf5.dtype('/dset2') == tf.float64
55+
assert hdf5.shape('/dset2') == [30, 20]
5656

57-
p1 = hdf5_io.read_hdf5(filename, specs['/dset1'])
58-
assert p1.dtype == tf.int32
59-
assert p1.shape == tf.TensorShape([10, 20])
57+
p1 = hdf5('/dset1')
6058
for i in range(10):
6159
vv = list([np.asarray([v for v in range(i, i + 20)])])
6260
assert np.all(p1[i].numpy() == vv)
6361

64-
dataset = hdf5_io.HDF5Dataset(filename, '/dset1').apply(
65-
tf.data.experimental.unbatch())
62+
dataset = tfio.IOTensor.from_hdf5(filename)('/dset1').to_dataset()
6663
i = 0
6764
for p in dataset:
6865
vv = list([np.asarray([v for v in range(i, i + 20)])])

0 commit comments

Comments
 (0)