1212
1313
1414class LSUNClass (data .Dataset ):
15- def __init__ (self , db_path , transform = None , target_transform = None ):
15+ def __init__ (self , root , transform = None , target_transform = None ):
1616 import lmdb
17- self .db_path = db_path
18- self .env = lmdb .open (db_path , max_readers = 1 , readonly = True , lock = False ,
17+ self .root = os .path .expanduser (root )
18+ self .transform = transform
19+ self .target_transform = target_transform
20+
21+ self .env = lmdb .open (root , max_readers = 1 , readonly = True , lock = False ,
1922 readahead = False , meminit = False )
2023 with self .env .begin (write = False ) as txn :
2124 self .length = txn .stat ()['entries' ]
22- cache_file = '_cache_' + db_path .replace ('/' , '_' )
25+ cache_file = '_cache_' + root .replace ('/' , '_' )
2326 if os .path .isfile (cache_file ):
2427 self .keys = pickle .load (open (cache_file , "rb" ))
2528 else :
2629 with self .env .begin (write = False ) as txn :
2730 self .keys = [key for key , _ in txn .cursor ()]
2831 pickle .dump (self .keys , open (cache_file , "wb" ))
29- self .transform = transform
30- self .target_transform = target_transform
3132
3233 def __getitem__ (self , index ):
3334 img , target = None , None
@@ -60,7 +61,7 @@ class LSUN(data.Dataset):
6061 `LSUN <http://lsun.cs.princeton.edu>`_ dataset.
6162
6263 Args:
63- db_path (string): Root directory for the database files.
64+ root (string): Root directory for the database files.
6465 classes (string or list): One of {'train', 'val', 'test'} or a list of
6566 categories to load. e,g. ['bedroom_train', 'church_train'].
6667 transform (callable, optional): A function/transform that takes in an PIL image
@@ -69,13 +70,16 @@ class LSUN(data.Dataset):
6970 target and transforms it.
7071 """
7172
72- def __init__ (self , db_path , classes = 'train' ,
73+ def __init__ (self , root , classes = 'train' ,
7374 transform = None , target_transform = None ):
7475 categories = ['bedroom' , 'bridge' , 'church_outdoor' , 'classroom' ,
7576 'conference_room' , 'dining_room' , 'kitchen' ,
7677 'living_room' , 'restaurant' , 'tower' ]
7778 dset_opts = ['train' , 'val' , 'test' ]
78- self .db_path = db_path
79+ self .root = os .path .expanduser (root )
80+ self .transform = transform
81+ self .target_transform = target_transform
82+
7983 if type (classes ) == str and classes in dset_opts :
8084 if classes == 'test' :
8185 classes = [classes ]
@@ -102,7 +106,7 @@ def __init__(self, db_path, classes='train',
102106 self .dbs = []
103107 for c in self .classes :
104108 self .dbs .append (LSUNClass (
105- db_path = db_path + '/' + c + '_lmdb' ,
109+ root = root + '/' + c + '_lmdb' ,
106110 transform = transform ))
107111
108112 self .indices = []
@@ -112,7 +116,6 @@ def __init__(self, db_path, classes='train',
112116 self .indices .append (count )
113117
114118 self .length = count
115- self .target_transform = target_transform
116119
117120 def __getitem__ (self , index ):
118121 """
@@ -146,6 +149,7 @@ def __repr__(self):
146149 fmt_str = 'Dataset ' + self .__class__ .__name__ + '\n '
147150 fmt_str += ' Number of datapoints: {}\n ' .format (self .__len__ ())
148151 fmt_str += ' Root Location: {}\n ' .format (self .root )
152+ fmt_str += ' Classes: {}\n ' .format (self .classes )
149153 tmp = ' Transforms (if any): '
150154 fmt_str += '{0}{1}\n ' .format (tmp , self .transform .__repr__ ().replace ('\n ' , '\n ' + ' ' * len (tmp )))
151155 tmp = ' Target Transforms (if any): '
0 commit comments