From c3700444b87ad135070653cfc5d9a3ed97711850 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 4 May 2019 14:47:24 +0000 Subject: [PATCH] Update LMDBDataset to support batch at the creation This PR updates LMDBDataset to support batching at creation time; it also switches to the new pattern to reduce unnecessary code duplication. Signed-off-by: Yong Tang Add test case for LMDB with batch Signed-off-by: Yong Tang --- tensorflow_io/core/BUILD | 1 + tensorflow_io/lmdb/BUILD | 5 +- tensorflow_io/lmdb/__init__.py | 2 +- .../lmdb/kernels/lmdb_dataset_ops.cc | 219 ------------------ tensorflow_io/lmdb/kernels/lmdb_input.cc | 158 +++++++++++++ .../lmdb/ops/{dataset_ops.cc => lmdb_ops.cc} | 19 +- .../ops/{lmdb_dataset_ops.py => lmdb_ops.py} | 25 +- tests/test_lmdb.py | 33 +++ 8 files changed, 228 insertions(+), 234 deletions(-) delete mode 100644 tensorflow_io/lmdb/kernels/lmdb_dataset_ops.cc create mode 100644 tensorflow_io/lmdb/kernels/lmdb_input.cc rename tensorflow_io/lmdb/ops/{dataset_ops.cc => lmdb_ops.cc} (65%) rename tensorflow_io/lmdb/python/ops/{lmdb_dataset_ops.py => lmdb_ops.py} (79%) diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index 2bc1a82cd..a39db8176 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -16,6 +16,7 @@ cc_library( ".", ], deps = [ + "@libarchive", "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", ], diff --git a/tensorflow_io/lmdb/BUILD b/tensorflow_io/lmdb/BUILD index 4e5a01e4f..efcd36833 100644 --- a/tensorflow_io/lmdb/BUILD +++ b/tensorflow_io/lmdb/BUILD @@ -5,8 +5,8 @@ package(default_visibility = ["//visibility:public"]) cc_binary( name = "python/ops/_lmdb_ops.so", srcs = [ - "kernels/lmdb_dataset_ops.cc", - "ops/dataset_ops.cc", + "kernels/lmdb_input.cc", + "ops/lmdb_ops.cc", ], copts = [ "-pthread", @@ -15,6 +15,7 @@ cc_binary( ], linkshared = 1, deps = [ + "//tensorflow_io/core:dataset_ops", "@lmdb", "@local_config_tf//:libtensorflow_framework", 
"@local_config_tf//:tf_header_lib", diff --git a/tensorflow_io/lmdb/__init__.py b/tensorflow_io/lmdb/__init__.py index b1d5b2de9..9bc363a91 100644 --- a/tensorflow_io/lmdb/__init__.py +++ b/tensorflow_io/lmdb/__init__.py @@ -21,7 +21,7 @@ from __future__ import division from __future__ import print_function -from tensorflow_io.lmdb.python.ops.lmdb_dataset_ops import LMDBDataset +from tensorflow_io.lmdb.python.ops.lmdb_ops import LMDBDataset from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow_io/lmdb/kernels/lmdb_dataset_ops.cc b/tensorflow_io/lmdb/kernels/lmdb_dataset_ops.cc deleted file mode 100644 index 84a52bf31..000000000 --- a/tensorflow_io/lmdb/kernels/lmdb_dataset_ops.cc +++ /dev/null @@ -1,219 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include - -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" -#include "tensorflow/core/platform/file_system.h" - -#include "lmdb.h" // NOLINT(build/include) - -namespace tensorflow { -namespace data { -namespace { - -class LMDBDatasetOp : public DatasetOpKernel { - public: - using DatasetOpKernel::DatasetOpKernel; - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - const Tensor* filenames_tensor; - OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); - OP_REQUIRES( - ctx, filenames_tensor->dims() <= 1, - errors::InvalidArgument("`filenames` must be a scalar or a vector.")); - - std::vector filenames; - filenames.reserve(filenames_tensor->NumElements()); - for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat()(i)); - } - - *output = new Dataset(ctx, filenames); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, const std::vector& filenames) - : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {} - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::LMDB")}); - } - - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = - new DataTypeVector({DT_STRING, DT_STRING}); - return *dtypes; - } - - const std::vector& output_shapes() const override { - static std::vector* shapes = - new std::vector({{}, {}}); - return *shapes; - } - - string DebugString() const override { return "LMDBDatasetOp::Dataset"; } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* filenames = nullptr; - TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); - 
TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params) {} - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - do { - if (mdb_cursor_) { - out_tensors->emplace_back(ctx->allocator({}), DT_STRING, - TensorShape({})); - Tensor& key_tensor = out_tensors->back(); - key_tensor.scalar()() = string( - static_cast(mdb_key_.mv_data), mdb_key_.mv_size); - - out_tensors->emplace_back(ctx->allocator({}), DT_STRING, - TensorShape({})); - Tensor& value_tensor = out_tensors->back(); - value_tensor.scalar()() = - string(static_cast(mdb_value_.mv_data), - mdb_value_.mv_size); - - int val; - val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); - if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { - return errors::InvalidArgument(mdb_strerror(val)); - } - if (val == MDB_NOTFOUND) { - ResetStreamsLocked(); - ++current_file_index_; - } - *end_of_sequence = false; - return Status::OK(); - } - if (current_file_index_ == dataset()->filenames_.size()) { - *end_of_sequence = true; - return Status::OK(); - } - - TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); - } while (true); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - return errors::Unimplemented( - "Checkpointing is currently not supported for LMDBDataset."); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - return errors::Unimplemented( - "Checkpointing is currently not supported for LMDBDataset."); - } - - private: - Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (current_file_index_ >= dataset()->filenames_.size()) { - return errors::InvalidArgument( - "current_file_index_:", current_file_index_, - " >= filenames_.size():", dataset()->filenames_.size()); - } - 
const string& filename = dataset()->filenames_[current_file_index_]; - - int val = mdb_env_create(&mdb_env_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; - - struct stat source_stat; - if (stat(filename.c_str(), &source_stat) == 0 && - (source_stat.st_mode & S_IFREG)) { - flags |= MDB_NOSUBDIR; - } - val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); - if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { - return errors::InvalidArgument(mdb_strerror(val)); - } - if (val == MDB_NOTFOUND) { - ResetStreamsLocked(); - } - return Status::OK(); - } - void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (mdb_env_ != nullptr) { - if (mdb_cursor_) { - mdb_cursor_close(mdb_cursor_); - mdb_cursor_ = nullptr; - } - mdb_dbi_close(mdb_env_, mdb_dbi_); - mdb_txn_abort(mdb_txn_); - mdb_env_close(mdb_env_); - mdb_txn_ = nullptr; - mdb_dbi_ = 0; - mdb_env_ = nullptr; - } - } - mutex mu_; - size_t current_file_index_ GUARDED_BY(mu_) = 0; - MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr; - MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr; - MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0; - MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr; - - MDB_val mdb_key_ GUARDED_BY(mu_); - MDB_val mdb_value_ GUARDED_BY(mu_); - }; - - const std::vector filenames_; - }; -}; - 
-REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), - LMDBDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow_io/lmdb/kernels/lmdb_input.cc b/tensorflow_io/lmdb/kernels/lmdb_input.cc new file mode 100644 index 000000000..5aa1622bc --- /dev/null +++ b/tensorflow_io/lmdb/kernels/lmdb_input.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "kernels/dataset_ops.h" +#include "lmdb.h" + +namespace tensorflow { +namespace data { + +class LMDBInputStream{ +public: + explicit LMDBInputStream(const string& filename) + : mdb_status_(MDB_SUCCESS) + , mdb_env_(nullptr) + , mdb_txn_(nullptr) + , mdb_dbi_(0) + , mdb_cursor_(nullptr) { + mdb_status_ = mdb_env_create(&mdb_env_); + if (mdb_status_ != MDB_SUCCESS) { + return; + } + int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; + + struct stat source_stat; + if (stat(filename.c_str(), &source_stat) == 0 && + (source_stat.st_mode & S_IFREG)) { + flags |= MDB_NOSUBDIR; + } + mdb_status_ = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); + if (mdb_status_ != MDB_SUCCESS) { + return; + } + mdb_status_ = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); + if (mdb_status_ != MDB_SUCCESS) { + return; + } + mdb_status_ = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); + if (mdb_status_ != 
MDB_SUCCESS) { + return; + } + mdb_status_ = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); + if (mdb_status_ != MDB_SUCCESS) { + return; + } + mdb_status_ = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); + if (mdb_status_ != MDB_SUCCESS && mdb_status_ != MDB_NOTFOUND) { + return; + } + if (mdb_status_ == MDB_NOTFOUND) { + // empty data, move on. + } + return; + } + ~LMDBInputStream() { + if (mdb_env_ != nullptr) { + if (mdb_cursor_) { + mdb_cursor_close(mdb_cursor_); + mdb_cursor_ = nullptr; + } + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + mdb_env_close(mdb_env_); + mdb_txn_ = nullptr; + mdb_dbi_ = 0; + mdb_env_ = nullptr; + } + } + Status ReadRecord(string* key, string* val) { + if (mdb_status_ != MDB_SUCCESS) { + if (mdb_status_ == MDB_NOTFOUND) { + return errors::OutOfRange("EOF reached"); + } + return errors::InvalidArgument(mdb_strerror(mdb_status_)); + } + *key = std::move(string(static_cast(mdb_key_.mv_data), mdb_key_.mv_size)); + *val = std::move(string(static_cast(mdb_value_.mv_data), mdb_value_.mv_size)); + mdb_status_ = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); + return Status::OK(); + } +private: + int mdb_status_ = MDB_SUCCESS; + MDB_env* mdb_env_ = nullptr; + MDB_txn* mdb_txn_ = nullptr; + MDB_dbi mdb_dbi_ = 0; + + MDB_cursor* mdb_cursor_ = nullptr; + MDB_val mdb_key_; + MDB_val mdb_value_; +}; + +class LMDBInput: public FileInput { + public: + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + if (state.get() == nullptr) { + // LMDB does not support compression or any file system so + // using filename() (instead of stream/s) here is fine. 
+ state.reset(new LMDBInputStream(filename())); + } + std::vector keys; + keys.reserve(record_to_read); + std::vector vals; + vals.reserve(record_to_read); + while ((*record_read) < record_to_read) { + string key, val; + Status status = state.get()->ReadRecord(&key, &val); + if (!(status.ok() || errors::IsOutOfRange(status))) { + return status; + } + if (!status.ok()) { + break; + } + keys.emplace_back(std::move(key)); + vals.emplace_back(std::move(val)); + (*record_read)++; + } + if (*record_read > 0) { + Tensor key_tensor(ctx->allocator({}), DT_STRING, {*record_read}); + Tensor val_tensor(ctx->allocator({}), DT_STRING, {*record_read}); + for (size_t i = 0; i < (*record_read); i++) { + key_tensor.flat()(i) = std::move(keys[i]); + val_tensor.flat()(i) = std::move(vals[i]); + } + out_tensors->emplace_back(std::move(key_tensor)); + out_tensors->emplace_back(std::move(val_tensor)); + } + return Status::OK(); + } + Status FromStream(io::InputStreamInterface* s) override { + return Status::OK(); + } + void EncodeAttributes(VariantTensorData* data) const override { + } + bool DecodeAttributes(const VariantTensorData& data) override { + return true; + } + protected: +}; + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(LMDBInput, "tensorflow::data::LMDBInput"); + +REGISTER_KERNEL_BUILDER(Name("LMDBInput").Device(DEVICE_CPU), + FileInputOp); +REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), + FileInputDatasetOp); +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/lmdb/ops/dataset_ops.cc b/tensorflow_io/lmdb/ops/lmdb_ops.cc similarity index 65% rename from tensorflow_io/lmdb/ops/dataset_ops.cc rename to tensorflow_io/lmdb/ops/lmdb_ops.cc index 5bcf5c99e..7585c09fd 100644 --- a/tensorflow_io/lmdb/ops/dataset_ops.cc +++ b/tensorflow_io/lmdb/ops/lmdb_ops.cc @@ -19,12 +19,27 @@ limitations under the License. 
namespace tensorflow { +REGISTER_OP("LMDBInput") + .Input("source: string") + .Output("handle: variant") + .Attr("filters: list(string) = []") + .Attr("columns: list(string) = []") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + REGISTER_OP("LMDBDataset") - .Input("filenames: string") + .Input("input: T") + .Input("batch: int64") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("T: {string, variant} = DT_VARIANT") .SetIsStateful() - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({})); + return Status::OK(); + }); } // namespace tensorflow diff --git a/tensorflow_io/lmdb/python/ops/lmdb_dataset_ops.py b/tensorflow_io/lmdb/python/ops/lmdb_ops.py similarity index 79% rename from tensorflow_io/lmdb/python/ops/lmdb_dataset_ops.py rename to tensorflow_io/lmdb/python/ops/lmdb_ops.py index 3d480461b..470d891ca 100644 --- a/tensorflow_io/lmdb/python/ops/lmdb_dataset_ops.py +++ b/tensorflow_io/lmdb/python/ops/lmdb_ops.py @@ -26,7 +26,7 @@ class LMDBDataset(data.Dataset): """A LMDB Dataset that reads the lmdb file.""" - def __init__(self, filenames): + def __init__(self, filenames, batch=None): """Create a `LMDBDataset`. `LMDBDataset` allows a user to read data from a mdb file as @@ -43,8 +43,8 @@ def __init__(self, filenames): Args: filenames: A `tf.string` tensor containing one or more filenames. 
""" - self._filenames = tensorflow.convert_to_tensor( - filenames, dtype=dtypes.string, name="filenames") + self._data_input = lmdb_ops.lmdb_input(filenames) + self._batch = 0 if batch is None else batch super(LMDBDataset, self).__init__() def _inputs(self): @@ -52,17 +52,22 @@ def _inputs(self): def _as_variant_tensor(self): return lmdb_ops.lmdb_dataset( - self._filenames, - (dtypes.string, dtypes.string), - (tensorflow.TensorShape([]), tensorflow.TensorShape([]))) + self._data_input, + self._batch, + output_types=self.output_types, + output_shapes=self.output_shapes) @property - def output_classes(self): - return tensorflow.Tensor, tensorflow.Tensor + def output_shapes(self): + return ( + tensorflow.TensorShape([]), + tensorflow.TensorShape([])) if self._batch == 0 else ( + tensorflow.TensorShape([None]), + tensorflow.TensorShape([None])) @property - def output_shapes(self): - return (tensorflow.TensorShape([]), tensorflow.TensorShape([])) + def output_classes(self): + return tensorflow.Tensor, tensorflow.Tensor @property def output_types(self): diff --git a/tests/test_lmdb.py b/tests/test_lmdb.py index abdc745c6..58e6644e5 100644 --- a/tests/test_lmdb.py +++ b/tests/test_lmdb.py @@ -59,6 +59,39 @@ def test_read_from_file(self): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def test_read_from_file_with_batch(self): + """test_read_from_file""" + super(LMDBDatasetTest, self).setUp() + # Copy database out because we need the path to be writable to use locks. 
+ path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_lmdb", "data.mdb") + self.db_path = os.path.join(self.get_temp_dir(), "data.mdb") + shutil.copy(path, self.db_path) + + filename = self.db_path + + dataset = lmdb_io.LMDBDataset([filename], batch=3) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for i in range(0, 9, 3): + k = [ + str(i).encode(), + str(i + 1).encode(), + str(i + 2).encode()] + v = [ + str(chr(ord("a") + i)).encode(), + str(chr(ord("a") + i + 1)).encode(), + str(chr(ord("a") + i + 2)).encode()] + self.assertAllEqual((k, v), sess.run(get_next)) + self.assertAllEqual( + ([str(9).encode()], [str('j').encode()]), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main()