diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 7641139daed..d1346bb4bdd 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -111,6 +111,7 @@ Stereo Matching CarlaStereo Kitti2012Stereo Kitti2015Stereo + CREStereo FallingThingsStereo SceneFlowStereo SintelStereo diff --git a/test/test_datasets.py b/test/test_datasets.py index ad31856cd01..b5ca24ab9c9 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2841,6 +2841,37 @@ def test_train_splits(self): datasets_utils.shape_test_for_stereo(left, right, disparity) +class CREStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.CREStereo + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, np.ndarray, type(None)) + + def inject_fake_data(self, tmpdir, config): + crestereo_dir = pathlib.Path(tmpdir) / "CREStereo" + os.makedirs(crestereo_dir, exist_ok=True) + + examples = {"tree": 2, "shapenet": 3, "reflective": 6, "hole": 5} + + for category_name in ["shapenet", "reflective", "tree", "hole"]: + split_dir = crestereo_dir / category_name + os.makedirs(split_dir, exist_ok=True) + num_examples = examples[category_name] + + for idx in range(num_examples): + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.jpg", size=(100, 100)) + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.jpg", size=(100, 100)) + # these are going to end up being gray scale images + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.disp.png", size=(1, 100, 100)) + datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.disp.png", size=(1, 100, 100)) + + return sum(examples.values()) + + def test_splits(self): + with self.create_dataset() as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo(left, right, disparity) + + class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.FallingThingsStereo ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(variant=("single", "mixed", "both")) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index d5303849a41..e809bbf1695 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,6 +1,7 @@ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel from ._stereo_matching import ( CarlaStereo, + CREStereo, ETH3DStereo, FallingThingsStereo, InStereo2k, @@ -118,6 +119,7 @@ "Kitti2012Stereo", "Kitti2015Stereo", "CarlaStereo", + "CREStereo", "FallingThingsStereo", "SceneFlowStereo", "SintelStereo", diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index 3938af68c7b..14fe1b60f44 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -363,6 +363,94 @@ def __getitem__(self, index: int) -> Tuple: return super().__getitem__(index) +class CREStereo(StereoMatchingDataset): + """Synthetic dataset used in training the `CREStereo `_ architecture. + Dataset details on the official paper `repo `_. + + The dataset is expected to have the following structure: :: + + root + CREStereo + tree + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + img2_left.jpg + img2_right.jpg + img2_left.disp.jpg + img2_right.disp.jpg + ... + shapenet + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + reflective + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + hole + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + + Args: + root (str): Root directory of the dataset. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + _has_built_in_disparity_mask = True + + def __init__( + self, + root: str, + transforms: Optional[Callable] = None, + ): + super().__init__(root, transforms) + + root = Path(root) / "CREStereo" + + dirs = ["shapenet", "reflective", "tree", "hole"] + + for s in dirs: + left_image_pattern = str(root / s / "*_left.jpg") + right_image_pattern = str(root / s / "*_right.jpg") + imgs = self._scan_pairs(left_image_pattern, right_image_pattern) + self._images += imgs + + left_disparity_pattern = str(root / s / "*_left.disp.png") + right_disparity_pattern = str(root / s / "*_right.disp.png") + disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + self._disparities += disparities + + def _read_disparity(self, file_path: str) -> Tuple: + disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] / 256.0 + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + """ + return super().__getitem__(index) + + class FallingThingsStereo(StereoMatchingDataset): """`FallingThings `_ dataset.