# Imports visible in this fragment; the module's remaining imports (os, pathlib,
# contextlib, torch, PIL.Image, datasets, datasets_utils, ...) are outside this view.
import shutil
import json
import random
import string
import io


class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.LSUN`` backed by small, randomly generated LMDB databases."""

    DATASET_CLASS = datasets.LSUN

    REQUIRED_PACKAGES = ("lmdb",)
    # ``classes`` is either a split name or an explicit list of "<category>_<split>" names.
    CONFIGS = datasets_utils.combinations_grid(
        classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
    )

    # Categories that get combined with a split name in ``_parse_classes``.
    _CATEGORIES = (
        "bedroom",
        "bridge",
        "church_outdoor",
        "classroom",
        "conference_room",
        "dining_room",
        "kitchen",
        "living_room",
        "restaurant",
        "tower",
    )

    def inject_fake_data(self, tmpdir, config):
        """Create one fake LMDB database per requested class below ``tmpdir``.

        Args:
            tmpdir: Directory in which the ``<class>_lmdb`` folders are created.
            config: Mapping whose ``"classes"`` entry is a split name or a list of class names.

        Returns:
            Total number of images written across all created databases.
        """
        root = pathlib.Path(tmpdir)
        return sum(self._create_lmdb(root, cls) for cls in self._parse_classes(config["classes"]))

    @contextlib.contextmanager
    def create_dataset(self, *args, **kwargs):
        """Wrap the parent ``create_dataset`` and remove stray ``_cache_*`` files afterwards.

        All arguments are forwarded unchanged to the parent implementation and its
        output is yielded as is.
        """
        with super().create_dataset(*args, **kwargs) as output:
            try:
                yield output
            finally:
                # Currently datasets.LSUN caches the keys in the current directory rather than in the root
                # directory. Thus, this creates a number of unique _cache_* files in the current directory that
                # will not be removed together with the temporary directory. Cleaning up in ``finally`` makes
                # sure the files are also removed when the body of the ``with`` statement raises.
                cache_dir = os.getcwd()
                for file in os.listdir(cache_dir):
                    if file.startswith("_cache_"):
                        os.remove(os.path.join(cache_dir, file))

    def _parse_classes(self, classes):
        """Expand a ``classes`` config value into a list of class names.

        Non-string values are returned unchanged. The split ``"test"`` maps to
        ``["test"]``; any other split is combined with every entry of
        ``_CATEGORIES`` as ``"<category>_<split>"``.
        """
        if not isinstance(classes, str):
            return classes

        split = classes
        if split == "test":
            return [split]

        return [f"{category}_{split}" for category in self._CATEGORIES]

    def _create_lmdb(self, root, cls):
        """Create the LMDB database ``<cls>_lmdb`` below ``root`` filled with random images.

        Args:
            root: ``pathlib.Path`` of the directory that will contain the database folder.
            cls: Class name used as prefix of the database folder.

        Returns:
            Number of images stored in the database.
        """
        lmdb = datasets_utils.lazy_importer.lmdb
        hexdigits_lowercase = string.digits + string.ascii_lowercase[:6]

        folder = f"{cls}_lmdb"

        # The upper bound of ``torch.randint`` is exclusive, so this draws 1, 2, or 3.
        num_images = torch.randint(1, 4, size=()).item()
        image_format = "webp"
        files = datasets_utils.create_image_folder(
            root, folder, lambda idx: f"{idx}.{image_format}", num_images
        )

        with lmdb.open(str(root / folder)) as env, env.begin(write=True) as txn:
            for file in files:
                # Each record is stored under a random key of 40 lowercase hex characters.
                key = "".join(random.choice(hexdigits_lowercase) for _ in range(40)).encode()

                buffer = io.BytesIO()
                # Close the image before deleting its file: an open handle would leak and can
                # make the removal fail on platforms that do not allow deleting open files.
                with Image.open(file) as image:
                    image.save(buffer, image_format)

                txn.put(key, buffer.getvalue())

                # Only the database should remain in the folder, not the source image files.
                os.remove(file)

        return num_images


if __name__ == "__main__":
    unittest.main()