From c38f688cded3f198050a0d8c94b3e1f2e04bc452 Mon Sep 17 00:00:00 2001 From: Ponku Date: Tue, 2 Aug 2022 16:10:57 +0100 Subject: [PATCH] Added CREStereo dataset --- docs/source/datasets.rst | 1 + test/test_datasets.py | 31 +++++++++ torchvision/datasets/__init__.py | 3 +- torchvision/datasets/_stereo_matching.py | 88 ++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 1 deletion(-) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index af7ac072e31..fe9c67072c2 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -111,6 +111,7 @@ Stereo Matching CarlaStereo Kitti2012Stereo Kitti2015Stereo + CREStereo Image pairs ~~~~~~~~~~~ diff --git a/test/test_datasets.py b/test/test_datasets.py index 54696b0d6a8..971664c5a39 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2841,5 +2841,36 @@ def test_train_splits(self): datasets_utils.shape_test_for_stereo(left, right, disparity) +class CREStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.CREStereo + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, np.ndarray, type(None)) + + def inject_fake_data(self, tmpdir, config): + crestereo_dir = pathlib.Path(tmpdir) / "CREStereo" + os.makedirs(crestereo_dir, exist_ok=True) + + examples = {"tree": 2, "shapenet": 3, "reflective": 6, "hole": 5} + + for category_name in ["shapenet", "reflective", "tree", "hole"]: + split_dir = crestereo_dir / category_name + os.makedirs(split_dir, exist_ok=True) + num_examples = examples[category_name] + + for idx in range(num_examples): + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.jpg", size=(100, 100)) + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.jpg", size=(100, 100)) + # these are going to end up being gray scale images + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.disp.png", size=(1, 100, 100)) + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.disp.png", size=(1, 100, 100)) + + return sum(examples.values()) + + def test_splits(self): + with self.create_dataset() as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo(left, right, disparity) + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index d8b6293fb42..4bde5b84405 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,5 +1,5 @@ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel -from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo +from ._stereo_matching import CarlaStereo, CREStereo, Kitti2012Stereo, Kitti2015Stereo from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -109,4 +109,5 @@ "Kitti2012Stereo", "Kitti2015Stereo", "CarlaStereo", + "CREStereo", ) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index de213fc0368..dfc0e813ecc 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -359,3 +359,91 @@ def __getitem__(self, index: int) -> Tuple: Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. """ return super().__getitem__(index) + + +class CREStereo(StereoMatchingDataset): + """Synthetic dataset used in training the `CREStereo `_ architecture. + Dataset details on the official paper `repo `_. + + The dataset is expected to have the following structure: :: + + root + CREStereo + tree + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + img2_left.jpg + img2_right.jpg + img2_left.disp.jpg + img2_right.disp.jpg + ... + shapenet + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + reflective + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + hole + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + + Args: + root (str): Root directory of the dataset. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + _has_built_in_disparity_mask = True + + def __init__( + self, + root: str, + transforms: Optional[Callable] = None, + ): + super().__init__(root, transforms) + + root = Path(root) / "CREStereo" + + dirs = ["shapenet", "reflective", "tree", "hole"] + + for s in dirs: + left_image_pattern = str(root / s / "*_left.jpg") + right_image_pattern = str(root / s / "*_right.jpg") + imgs = self._scan_pairs(left_image_pattern, right_image_pattern) + self._images += imgs + + left_disparity_pattern = str(root / s / "*_left.disp.png") + right_disparity_pattern = str(root / s / "*_right.disp.png") + disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + self._disparities += disparities + + def _read_disparity(self, file_path: str) -> Tuple: + disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + """ + return super().__getitem__(index)