Skip to content

Commit dcd4ed9

Browse files
committed
Add test case for LMDB with batch
Signed-off-by: Yong Tang <[email protected]>
1 parent 29ebda9 commit dcd4ed9

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/test_lmdb.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,38 @@ def test_read_from_file(self):
5959
with self.assertRaises(errors.OutOfRangeError):
6060
sess.run(get_next)
6161

62+
def test_read_from_file_with_batch(self):
63+
"""test_read_from_file"""
64+
super(LMDBDatasetTest, self).setUp()
65+
# Copy database out because we need the path to be writable to use locks.
66+
path = os.path.join(
67+
os.path.dirname(os.path.abspath(__file__)), "test_lmdb", "data.mdb")
68+
self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
69+
shutil.copy(path, self.db_path)
70+
71+
filename = self.db_path
72+
73+
dataset = lmdb_io.LMDBDataset([filename], batch=3)
74+
iterator = dataset.make_initializable_iterator()
75+
init_op = iterator.initializer
76+
get_next = iterator.get_next()
77+
78+
with self.cached_session() as sess:
79+
sess.run(init_op)
80+
for i in range(0, 9, 3):
81+
k = [
82+
str(i).encode(),
83+
str(i + 1).encode(),
84+
str(i + 2).encode()]
85+
v = [
86+
str(chr(ord("a") + i)).encode(),
87+
str(chr(ord("a") + i + 1)).encode(),
88+
str(chr(ord("a") + i + 2)).encode()]
89+
self.assertAllEqual((k, v), sess.run(get_next))
90+
self.assertAllEqual((['9'], ['j']), sess.run(get_next))
91+
with self.assertRaises(errors.OutOfRangeError):
92+
sess.run(get_next)
93+
6294

6395
if __name__ == "__main__":
6496
test.main()

0 commit comments

Comments
 (0)