From d7cdd9598532721f2fe060d414d8a2e999709dcf Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 4 Mar 2021 12:58:47 +0100 Subject: [PATCH] [FBcode->GH] Add back cifar_root --- test/fakedata_generation.py | 55 +++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py index 4249dedd54e..dac415df110 100644 --- a/test/fakedata_generation.py +++ b/test/fakedata_generation.py @@ -88,6 +88,61 @@ def _make_label_file(filename, num_images): yield tmp_dir +@contextlib.contextmanager +def cifar_root(version): + def _get_version_params(version): + if version == 'CIFAR10': + return { + 'base_folder': 'cifar-10-batches-py', + 'train_files': ['data_batch_{}'.format(batch) for batch in range(1, 6)], + 'test_file': 'test_batch', + 'target_key': 'labels', + 'meta_file': 'batches.meta', + 'classes_key': 'label_names', + } + elif version == 'CIFAR100': + return { + 'base_folder': 'cifar-100-python', + 'train_files': ['train'], + 'test_file': 'test', + 'target_key': 'fine_labels', + 'meta_file': 'meta', + 'classes_key': 'fine_label_names', + } + else: + raise ValueError + + def _make_pickled_file(obj, file): + with open(file, 'wb') as fh: + pickle.dump(obj, fh, 2) + + def _make_data_file(file, target_key): + obj = { + 'data': np.zeros((1, 32 * 32 * 3), dtype=np.uint8), + target_key: [0] + } + _make_pickled_file(obj, file) + + def _make_meta_file(file, classes_key): + obj = { + classes_key: ['fakedata'], + } + _make_pickled_file(obj, file) + + params = _get_version_params(version) + with get_tmp_dir() as root: + base_folder = os.path.join(root, params['base_folder']) + os.mkdir(base_folder) + + for file in list(params['train_files']) + [params['test_file']]: + _make_data_file(os.path.join(base_folder, file), params['target_key']) + + _make_meta_file(os.path.join(base_folder, params['meta_file']), + params['classes_key']) + + yield root + + @contextlib.contextmanager def imagenet_root(): import scipy.io as sio