@@ -59,6 +59,38 @@ def test_read_from_file(self):
5959 with self .assertRaises (errors .OutOfRangeError ):
6060 sess .run (get_next )
6161
62+ def test_read_from_file_with_batch (self ):
63+ """test_read_from_file"""
64+ super (LMDBDatasetTest , self ).setUp ()
65+ # Copy database out because we need the path to be writable to use locks.
66+ path = os .path .join (
67+ os .path .dirname (os .path .abspath (__file__ )), "test_lmdb" , "data.mdb" )
68+ self .db_path = os .path .join (self .get_temp_dir (), "data.mdb" )
69+ shutil .copy (path , self .db_path )
70+
71+ filename = self .db_path
72+
73+ dataset = lmdb_io .LMDBDataset ([filename ], batch = 3 )
74+ iterator = dataset .make_initializable_iterator ()
75+ init_op = iterator .initializer
76+ get_next = iterator .get_next ()
77+
78+ with self .cached_session () as sess :
79+ sess .run (init_op )
80+ for i in range (0 , 9 , 3 ):
81+ k = [
82+ str (i ).encode (),
83+ str (i + 1 ).encode (),
84+ str (i + 2 ).encode ()]
85+ v = [
86+ str (chr (ord ("a" ) + i )).encode (),
87+ str (chr (ord ("a" ) + i + 1 )).encode (),
88+ str (chr (ord ("a" ) + i + 2 )).encode ()]
89+ self .assertAllEqual ((k , v ), sess .run (get_next ))
90+ self .assertAllEqual ((['9' ], ['j' ]), sess .run (get_next ))
91+ with self .assertRaises (errors .OutOfRangeError ):
92+ sess .run (get_next )
93+
6294
6395if __name__ == "__main__" :
6496 test .main ()
0 commit comments