Skip to content

Commit 668927e

Browse files
authored
[FBcode->GH] Add back cifar_root (#3503)
1 parent d4d36e6 commit 668927e

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

test/fakedata_generation.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,61 @@ def _make_label_file(filename, num_images):
8888
yield tmp_dir
8989

9090

91+
@contextlib.contextmanager
92+
def cifar_root(version):
93+
def _get_version_params(version):
94+
if version == 'CIFAR10':
95+
return {
96+
'base_folder': 'cifar-10-batches-py',
97+
'train_files': ['data_batch_{}'.format(batch) for batch in range(1, 6)],
98+
'test_file': 'test_batch',
99+
'target_key': 'labels',
100+
'meta_file': 'batches.meta',
101+
'classes_key': 'label_names',
102+
}
103+
elif version == 'CIFAR100':
104+
return {
105+
'base_folder': 'cifar-100-python',
106+
'train_files': ['train'],
107+
'test_file': 'test',
108+
'target_key': 'fine_labels',
109+
'meta_file': 'meta',
110+
'classes_key': 'fine_label_names',
111+
}
112+
else:
113+
raise ValueError
114+
115+
def _make_pickled_file(obj, file):
116+
with open(file, 'wb') as fh:
117+
pickle.dump(obj, fh, 2)
118+
119+
def _make_data_file(file, target_key):
120+
obj = {
121+
'data': np.zeros((1, 32 * 32 * 3), dtype=np.uint8),
122+
target_key: [0]
123+
}
124+
_make_pickled_file(obj, file)
125+
126+
def _make_meta_file(file, classes_key):
127+
obj = {
128+
classes_key: ['fakedata'],
129+
}
130+
_make_pickled_file(obj, file)
131+
132+
params = _get_version_params(version)
133+
with get_tmp_dir() as root:
134+
base_folder = os.path.join(root, params['base_folder'])
135+
os.mkdir(base_folder)
136+
137+
for file in list(params['train_files']) + [params['test_file']]:
138+
_make_data_file(os.path.join(base_folder, file), params['target_key'])
139+
140+
_make_meta_file(os.path.join(base_folder, params['meta_file']),
141+
params['classes_key'])
142+
143+
yield root
144+
145+
91146
@contextlib.contextmanager
92147
def imagenet_root():
93148
import scipy.io as sio

0 commit comments

Comments
 (0)