@@ -88,6 +88,61 @@ def _make_label_file(filename, num_images):
8888 yield tmp_dir
8989
9090
91+ @contextlib .contextmanager
92+ def cifar_root (version ):
93+ def _get_version_params (version ):
94+ if version == 'CIFAR10' :
95+ return {
96+ 'base_folder' : 'cifar-10-batches-py' ,
97+ 'train_files' : ['data_batch_{}' .format (batch ) for batch in range (1 , 6 )],
98+ 'test_file' : 'test_batch' ,
99+ 'target_key' : 'labels' ,
100+ 'meta_file' : 'batches.meta' ,
101+ 'classes_key' : 'label_names' ,
102+ }
103+ elif version == 'CIFAR100' :
104+ return {
105+ 'base_folder' : 'cifar-100-python' ,
106+ 'train_files' : ['train' ],
107+ 'test_file' : 'test' ,
108+ 'target_key' : 'fine_labels' ,
109+ 'meta_file' : 'meta' ,
110+ 'classes_key' : 'fine_label_names' ,
111+ }
112+ else :
113+ raise ValueError
114+
115+ def _make_pickled_file (obj , file ):
116+ with open (file , 'wb' ) as fh :
117+ pickle .dump (obj , fh , 2 )
118+
119+ def _make_data_file (file , target_key ):
120+ obj = {
121+ 'data' : np .zeros ((1 , 32 * 32 * 3 ), dtype = np .uint8 ),
122+ target_key : [0 ]
123+ }
124+ _make_pickled_file (obj , file )
125+
126+ def _make_meta_file (file , classes_key ):
127+ obj = {
128+ classes_key : ['fakedata' ],
129+ }
130+ _make_pickled_file (obj , file )
131+
132+ params = _get_version_params (version )
133+ with get_tmp_dir () as root :
134+ base_folder = os .path .join (root , params ['base_folder' ])
135+ os .mkdir (base_folder )
136+
137+ for file in list (params ['train_files' ]) + [params ['test_file' ]]:
138+ _make_data_file (os .path .join (base_folder , file ), params ['target_key' ])
139+
140+ _make_meta_file (os .path .join (base_folder , params ['meta_file' ]),
141+ params ['classes_key' ])
142+
143+ yield root
144+
145+
91146@contextlib .contextmanager
92147def imagenet_root ():
93148 import scipy .io as sio
0 commit comments