diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD
index 2bc1a82cd..a39db8176 100644
--- a/tensorflow_io/core/BUILD
+++ b/tensorflow_io/core/BUILD
@@ -16,6 +16,7 @@ cc_library(
         ".",
     ],
     deps = [
+        "@libarchive",
         "@local_config_tf//:libtensorflow_framework",
         "@local_config_tf//:tf_header_lib",
     ],
diff --git a/tensorflow_io/lmdb/BUILD b/tensorflow_io/lmdb/BUILD
index 4e5a01e4f..efcd36833 100644
--- a/tensorflow_io/lmdb/BUILD
+++ b/tensorflow_io/lmdb/BUILD
@@ -5,8 +5,8 @@ package(default_visibility = ["//visibility:public"])
 cc_binary(
     name = "python/ops/_lmdb_ops.so",
     srcs = [
-        "kernels/lmdb_dataset_ops.cc",
-        "ops/dataset_ops.cc",
+        "kernels/lmdb_input.cc",
+        "ops/lmdb_ops.cc",
     ],
     copts = [
         "-pthread",
@@ -15,6 +15,7 @@ cc_binary(
     ],
     linkshared = 1,
     deps = [
+        "//tensorflow_io/core:dataset_ops",
         "@lmdb",
         "@local_config_tf//:libtensorflow_framework",
         "@local_config_tf//:tf_header_lib",
diff --git a/tensorflow_io/lmdb/__init__.py b/tensorflow_io/lmdb/__init__.py
index b1d5b2de9..9bc363a91 100644
--- a/tensorflow_io/lmdb/__init__.py
+++ b/tensorflow_io/lmdb/__init__.py
@@ -21,7 +21,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow_io.lmdb.python.ops.lmdb_dataset_ops import LMDBDataset
+from tensorflow_io.lmdb.python.ops.lmdb_ops import LMDBDataset
 
 from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow_io/lmdb/kernels/lmdb_dataset_ops.cc b/tensorflow_io/lmdb/kernels/lmdb_dataset_ops.cc
deleted file mode 100644
index 84a52bf31..000000000
--- a/tensorflow_io/lmdb/kernels/lmdb_dataset_ops.cc
+++ /dev/null
@@ -1,219 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <sys/stat.h>
-
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
-#include "tensorflow/core/platform/file_system.h"
-
-#include "lmdb.h"  // NOLINT(build/include)
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-class LMDBDatasetOp : public DatasetOpKernel {
- public:
-  using DatasetOpKernel::DatasetOpKernel;
-  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
-    const Tensor* filenames_tensor;
-    OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
-    OP_REQUIRES(
-        ctx, filenames_tensor->dims() <= 1,
-        errors::InvalidArgument("`filenames` must be a scalar or a vector."));
-
-    std::vector<string> filenames;
-    filenames.reserve(filenames_tensor->NumElements());
-    for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
-      filenames.push_back(filenames_tensor->flat<string>()(i));
-    }
-
-    *output = new Dataset(ctx, filenames);
-  }
-
- private:
-  class Dataset : public DatasetBase {
-   public:
-    Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
-        : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}
-
-    std::unique_ptr<IteratorBase> MakeIteratorInternal(
-        const string& prefix) const override {
-      return absl::make_unique<Iterator>(
-          Iterator::Params{this, strings::StrCat(prefix, "::LMDB")});
-    }
-
-    const DataTypeVector& output_dtypes() const override {
-      static DataTypeVector* dtypes =
-          new DataTypeVector({DT_STRING, DT_STRING});
-      return *dtypes;
-    }
-
-    const std::vector<PartialTensorShape>& output_shapes() const override {
-      static std::vector<PartialTensorShape>* shapes =
-          new std::vector<PartialTensorShape>({{}, {}});
-      return *shapes;
-    }
-
-    string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
-
-   protected:
-    Status AsGraphDefInternal(SerializationContext* ctx,
-                              DatasetGraphDefBuilder* b,
-                              Node** output) const override {
-      Node* filenames = nullptr;
-      TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
-      TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
-      return Status::OK();
-    }
-
-   private:
-    class Iterator : public DatasetIterator<Dataset> {
-     public:
-      explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params) {}
-
-      Status GetNextInternal(IteratorContext* ctx,
-                             std::vector<Tensor>* out_tensors,
-                             bool* end_of_sequence) override {
-        mutex_lock l(mu_);
-        do {
-          if (mdb_cursor_) {
-            out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
-                                      TensorShape({}));
-            Tensor& key_tensor = out_tensors->back();
-            key_tensor.scalar<string>()() = string(
-                static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
-
-            out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
-                                      TensorShape({}));
-            Tensor& value_tensor = out_tensors->back();
-            value_tensor.scalar<string>()() =
-                string(static_cast<const char*>(mdb_value_.mv_data),
-                       mdb_value_.mv_size);
-
-            int val;
-            val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
-            if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
-              return errors::InvalidArgument(mdb_strerror(val));
-            }
-            if (val == MDB_NOTFOUND) {
-              ResetStreamsLocked();
-              ++current_file_index_;
-            }
-            *end_of_sequence = false;
-            return Status::OK();
-          }
-          if (current_file_index_ == dataset()->filenames_.size()) {
-            *end_of_sequence = true;
-            return Status::OK();
-          }
-
-          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
-        } while (true);
-      }
-
-     protected:
-      Status SaveInternal(IteratorStateWriter* writer) override {
-        return errors::Unimplemented(
-            "Checkpointing is currently not supported for LMDBDataset.");
-      }
-
-      Status RestoreInternal(IteratorContext* ctx,
-                             IteratorStateReader* reader) override {
-        return errors::Unimplemented(
-            "Checkpointing is currently not supported for LMDBDataset.");
-      }
-
-     private:
-      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        if (current_file_index_ >= dataset()->filenames_.size()) {
-          return errors::InvalidArgument(
-              "current_file_index_:", current_file_index_,
-              " >= filenames_.size():", dataset()->filenames_.size());
-        }
-        const string& filename = dataset()->filenames_[current_file_index_];
-
-        int val = mdb_env_create(&mdb_env_);
-        if (val != MDB_SUCCESS) {
-          return errors::InvalidArgument(mdb_strerror(val));
-        }
-        int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
-
-        struct stat source_stat;
-        if (stat(filename.c_str(), &source_stat) == 0 &&
-            (source_stat.st_mode & S_IFREG)) {
-          flags |= MDB_NOSUBDIR;
-        }
-        val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
-        if (val != MDB_SUCCESS) {
-          return errors::InvalidArgument(mdb_strerror(val));
-        }
-        val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
-        if (val != MDB_SUCCESS) {
-          return errors::InvalidArgument(mdb_strerror(val));
-        }
-        val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
-        if (val != MDB_SUCCESS) {
-          return errors::InvalidArgument(mdb_strerror(val));
-        }
-        val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
-        if (val != MDB_SUCCESS) {
-          return errors::InvalidArgument(mdb_strerror(val));
-        }
-        val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
-        if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
-          return errors::InvalidArgument(mdb_strerror(val));
-        }
-        if (val == MDB_NOTFOUND) {
-          ResetStreamsLocked();
-        }
-        return Status::OK();
-      }
-      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        if (mdb_env_ != nullptr) {
-          if (mdb_cursor_) {
-            mdb_cursor_close(mdb_cursor_);
-            mdb_cursor_ = nullptr;
-          }
-          mdb_dbi_close(mdb_env_, mdb_dbi_);
-          mdb_txn_abort(mdb_txn_);
-          mdb_env_close(mdb_env_);
-          mdb_txn_ = nullptr;
-          mdb_dbi_ = 0;
-          mdb_env_ = nullptr;
-        }
-      }
-      mutex mu_;
-      size_t current_file_index_ GUARDED_BY(mu_) = 0;
-      MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
-      MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
-      MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
-      MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;
-
-      MDB_val mdb_key_ GUARDED_BY(mu_);
-      MDB_val mdb_value_ GUARDED_BY(mu_);
-    };
-
-    const std::vector<string> filenames_;
-  };
-};
-
-REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU),
-                        LMDBDatasetOp);
-
-}  // namespace
-}  // namespace data
-}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow_io/lmdb/kernels/lmdb_input.cc b/tensorflow_io/lmdb/kernels/lmdb_input.cc
new file mode 100644
index 000000000..5aa1622bc
--- /dev/null
+++ b/tensorflow_io/lmdb/kernels/lmdb_input.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "kernels/dataset_ops.h"
+#include "lmdb.h"
+
+namespace tensorflow {
+namespace data {
+
+class LMDBInputStream {
+ public:
+  explicit LMDBInputStream(const string& filename)
+      : mdb_status_(MDB_SUCCESS),
+        mdb_env_(nullptr),
+        mdb_txn_(nullptr),
+        mdb_dbi_(0),
+        mdb_cursor_(nullptr) {
+    mdb_status_ = mdb_env_create(&mdb_env_);
+    if (mdb_status_ != MDB_SUCCESS) {
+      return;
+    }
+    int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
+
+    struct stat source_stat;
+    if (stat(filename.c_str(), &source_stat) == 0 &&
+        (source_stat.st_mode & S_IFREG)) {
+      flags |= MDB_NOSUBDIR;
+    }
+    mdb_status_ = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
+    if (mdb_status_ != MDB_SUCCESS) {
+      return;
+    }
+    mdb_status_ = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
+    if (mdb_status_ != MDB_SUCCESS) {
+      return;
+    }
+    mdb_status_ = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
+    if (mdb_status_ != MDB_SUCCESS) {
+      return;
+    }
+    mdb_status_ = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
+    if (mdb_status_ != MDB_SUCCESS) {
+      return;
+    }
+    mdb_status_ = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
+    if (mdb_status_ != MDB_SUCCESS && mdb_status_ != MDB_NOTFOUND) {
+      return;
+    }
+    if (mdb_status_ == MDB_NOTFOUND) {
+      // Empty database; ReadRecord() will report OutOfRange.
+    }
+    return;
+  }
+  ~LMDBInputStream() {
+    if (mdb_env_ != nullptr) {
+      if (mdb_cursor_) {
+        mdb_cursor_close(mdb_cursor_);
+        mdb_cursor_ = nullptr;
+      }
+      mdb_dbi_close(mdb_env_, mdb_dbi_);
+      mdb_txn_abort(mdb_txn_);
+      mdb_env_close(mdb_env_);
+      mdb_txn_ = nullptr;
+      mdb_dbi_ = 0;
+      mdb_env_ = nullptr;
+    }
+  }
+  Status ReadRecord(string* key, string* val) {
+    if (mdb_status_ != MDB_SUCCESS) {
+      if (mdb_status_ == MDB_NOTFOUND) {
+        return errors::OutOfRange("EOF reached");
+      }
+      return errors::InvalidArgument(mdb_strerror(mdb_status_));
+    }
+    *key = string(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+    *val = string(static_cast<const char*>(mdb_value_.mv_data),
+                  mdb_value_.mv_size);
+    mdb_status_ = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
+    return Status::OK();
+  }
+
+ private:
+  int mdb_status_ = MDB_SUCCESS;
+  MDB_env* mdb_env_ = nullptr;
+  MDB_txn* mdb_txn_ = nullptr;
+  MDB_dbi mdb_dbi_ = 0;
+
+  MDB_cursor* mdb_cursor_ = nullptr;
+  MDB_val mdb_key_;
+  MDB_val mdb_value_;
+};
+
+class LMDBInput : public FileInput<LMDBInputStream> {
+ public:
+  Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx,
+                    std::unique_ptr<LMDBInputStream>& state,
+                    int64 record_to_read, int64* record_read,
+                    std::vector<Tensor>* out_tensors) const override {
+    if (state.get() == nullptr) {
+      // LMDB does not go through compression or the file system layer,
+      // so using filename() (instead of the stream `s`) here is fine.
+      state.reset(new LMDBInputStream(filename()));
+    }
+    std::vector<string> keys;
+    keys.reserve(record_to_read);
+    std::vector<string> vals;
+    vals.reserve(record_to_read);
+    while ((*record_read) < record_to_read) {
+      string key, val;
+      Status status = state.get()->ReadRecord(&key, &val);
+      if (!(status.ok() || errors::IsOutOfRange(status))) {
+        return status;
+      }
+      if (!status.ok()) {
+        break;
+      }
+      keys.emplace_back(std::move(key));
+      vals.emplace_back(std::move(val));
+      (*record_read)++;
+    }
+    if (*record_read > 0) {
+      Tensor key_tensor(ctx->allocator({}), DT_STRING, {*record_read});
+      Tensor val_tensor(ctx->allocator({}), DT_STRING, {*record_read});
+      for (int64 i = 0; i < (*record_read); i++) {
+        key_tensor.flat<string>()(i) = std::move(keys[i]);
+        val_tensor.flat<string>()(i) = std::move(vals[i]);
+      }
+      out_tensors->emplace_back(std::move(key_tensor));
+      out_tensors->emplace_back(std::move(val_tensor));
+    }
+    return Status::OK();
+  }
+  Status FromStream(io::InputStreamInterface* s) override {
+    return Status::OK();
+  }
+  void EncodeAttributes(VariantTensorData* data) const override {
+  }
+  bool DecodeAttributes(const VariantTensorData& data) override {
+    return true;
+  }
+
+ protected:
+};
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(LMDBInput,
+                                       "tensorflow::data::LMDBInput");
+
+REGISTER_KERNEL_BUILDER(Name("LMDBInput").Device(DEVICE_CPU),
+                        FileInputOp<LMDBInput>);
+REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU),
+                        FileInputDatasetOp<LMDBInput>);
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow_io/lmdb/ops/dataset_ops.cc b/tensorflow_io/lmdb/ops/lmdb_ops.cc
similarity index 65%
rename from tensorflow_io/lmdb/ops/dataset_ops.cc
rename to tensorflow_io/lmdb/ops/lmdb_ops.cc
index 5bcf5c99e..7585c09fd 100644
--- a/tensorflow_io/lmdb/ops/dataset_ops.cc
+++ b/tensorflow_io/lmdb/ops/lmdb_ops.cc
@@ -19,12 +19,27 @@ limitations under the License.
 
 namespace tensorflow {
 
+REGISTER_OP("LMDBInput")
+    .Input("source: string")
+    .Output("handle: variant")
+    .Attr("filters: list(string) = []")
+    .Attr("columns: list(string) = []")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      return Status::OK();
+    });
+
 REGISTER_OP("LMDBDataset")
-    .Input("filenames: string")
+    .Input("input: T")
+    .Input("batch: int64")
     .Output("handle: variant")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
+    .Attr("T: {string, variant} = DT_VARIANT")
     .SetIsStateful()
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({}));
+      return Status::OK();
+    });
 
 }  // namespace tensorflow
diff --git a/tensorflow_io/lmdb/python/ops/lmdb_dataset_ops.py b/tensorflow_io/lmdb/python/ops/lmdb_ops.py
similarity index 79%
rename from tensorflow_io/lmdb/python/ops/lmdb_dataset_ops.py
rename to tensorflow_io/lmdb/python/ops/lmdb_ops.py
index 3d480461b..470d891ca 100644
--- a/tensorflow_io/lmdb/python/ops/lmdb_dataset_ops.py
+++ b/tensorflow_io/lmdb/python/ops/lmdb_ops.py
@@ -26,7 +26,7 @@
 class LMDBDataset(data.Dataset):
   """A LMDB Dataset that reads the lmdb file."""
 
-  def __init__(self, filenames):
+  def __init__(self, filenames, batch=None):
     """Create a `LMDBDataset`.
 
     `LMDBDataset` allows a user to read data from a mdb file as
@@ -43,8 +43,9 @@ def __init__(self, filenames):
     Args:
       filenames: A `tf.string` tensor containing one or more filenames.
+      batch: (Optional.) Records to read per batch; 0 means no batching.
""" - self._filenames = tensorflow.convert_to_tensor( - filenames, dtype=dtypes.string, name="filenames") + self._data_input = lmdb_ops.lmdb_input(filenames) + self._batch = 0 if batch is None else batch super(LMDBDataset, self).__init__() def _inputs(self): @@ -52,17 +52,22 @@ def _inputs(self): def _as_variant_tensor(self): return lmdb_ops.lmdb_dataset( - self._filenames, - (dtypes.string, dtypes.string), - (tensorflow.TensorShape([]), tensorflow.TensorShape([]))) + self._data_input, + self._batch, + output_types=self.output_types, + output_shapes=self.output_shapes) @property - def output_classes(self): - return tensorflow.Tensor, tensorflow.Tensor + def output_shapes(self): + return ( + tensorflow.TensorShape([]), + tensorflow.TensorShape([])) if self._batch == 0 else ( + tensorflow.TensorShape([None]), + tensorflow.TensorShape([None])) @property - def output_shapes(self): - return (tensorflow.TensorShape([]), tensorflow.TensorShape([])) + def output_classes(self): + return tensorflow.Tensor, tensorflow.Tensor @property def output_types(self): diff --git a/tests/test_lmdb.py b/tests/test_lmdb.py index abdc745c6..58e6644e5 100644 --- a/tests/test_lmdb.py +++ b/tests/test_lmdb.py @@ -59,6 +59,39 @@ def test_read_from_file(self): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def test_read_from_file_with_batch(self): + """test_read_from_file""" + super(LMDBDatasetTest, self).setUp() + # Copy database out because we need the path to be writable to use locks. + path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_lmdb", "data.mdb") + self.db_path = os.path.join(self.get_temp_dir(), "data.mdb") + shutil.copy(path, self.db_path) + + filename = self.db_path + + dataset = lmdb_io.LMDBDataset([filename], batch=3) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.cached_session() as sess: + sess.run(init_op) + for i in range(0, 9, 3): + k = [ + str(i).encode(), + str(i + 1).encode(), + str(i + 2).encode()] + v = [ + str(chr(ord("a") + i)).encode(), + str(chr(ord("a") + i + 1)).encode(), + str(chr(ord("a") + i + 2)).encode()] + self.assertAllEqual((k, v), sess.run(get_next)) + self.assertAllEqual( + ([str(9).encode()], [str('j').encode()]), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main()