diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 572f526dc2e..7641139daed 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -115,6 +115,7 @@ Stereo Matching SceneFlowStereo SintelStereo InStereo2k + ETH3DStereo Image pairs ~~~~~~~~~~~ diff --git a/test/test_datasets.py b/test/test_datasets.py index 3a17758531c..ad31856cd01 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2872,15 +2872,26 @@ def inject_fake_data(self, tmpdir, config): os.makedirs(fallingthings_dir, exist_ok=True) num_examples = {"single": 2, "mixed": 3, "both": 4}.get(config["variant"], 0) + variants = { "single": ["single"], "mixed": ["mixed"], "both": ["single", "mixed"], }.get(config["variant"], []) + variant_dir_prefixes = { + "single": 1, + "mixed": 0, + } + for variant_name in variants: variant_dir = pathlib.Path(fallingthings_dir) / variant_name os.makedirs(variant_dir, exist_ok=True) + + for i in range(variant_dir_prefixes[variant_name]): + variant_dir = variant_dir / f"{i:02d}" + os.makedirs(variant_dir, exist_ok=True) + for i in range(num_examples): self._make_scene_folder( root=variant_dir, @@ -3109,5 +3120,72 @@ def test_bad_input(self): pass +class ETH3DStereoestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.ETH3DStereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + @staticmethod + def _create_scene_folder(num_examples: int, root_dir: str): + # make the root_dir if it does not exits + root_dir = pathlib.Path(root_dir) + os.makedirs(root_dir, exist_ok=True) + + for i in range(num_examples): + scene_dir = root_dir / f"scene_{i}" + os.makedirs(scene_dir, exist_ok=True) + # populate with left right images + datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100)) + + @staticmethod + def _create_annotation_folder(num_examples: int, root_dir: str): + # make the root_dir if it does not exits + root_dir = pathlib.Path(root_dir) + os.makedirs(root_dir, exist_ok=True) + + # create scene directories + for i in range(num_examples): + scene_dir = root_dir / f"scene_{i}" + os.makedirs(scene_dir, exist_ok=True) + # populate with a random png file for occlusion mask, and a pfm file for disparity + datasets_utils.create_image_file(root=scene_dir, name="mask0nocc.png", size=(1, 100, 100)) + + pfm_path = scene_dir / "disp0GT.pfm" + datasets_utils.make_fake_pfm_file(h=100, w=100, file_name=pfm_path) + + def inject_fake_data(self, tmpdir, config): + eth3d_dir = pathlib.Path(tmpdir) / "ETH3D" + + num_examples = 2 if config["split"] == "train" else 3 + + split_name = "two_view_training" if config["split"] == "train" else "two_view_test" + split_dir = eth3d_dir / split_name + self._create_scene_folder(num_examples, split_dir) + + if config["split"] == "train": + annot_dir = eth3d_dir / "two_view_training_gt" + self._create_annotation_folder(num_examples, annot_dir) + + return num_examples + + def test_training_splits(self): + with self.create_dataset(split="train") as (dataset, _): + for left, right, disparity, valid_mask in dataset: + datasets_utils.shape_test_for_stereo(left, right, disparity, valid_mask) + + def test_testing_splits(self): + with self.create_dataset(split="test") as (dataset, _): + assert all(d == (None, None) for d in dataset._disparities) + for left, right, disparity, valid_mask in dataset: + assert valid_mask is None + datasets_utils.shape_test_for_stereo(left, right, disparity) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 68b8ec61d6f..d5303849a41 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,6 +1,7 @@ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel from ._stereo_matching import ( CarlaStereo, + ETH3DStereo, FallingThingsStereo, InStereo2k, Kitti2012Stereo, @@ -121,4 +122,5 @@ "SceneFlowStereo", "SintelStereo", "InStereo2k", + "ETH3DStereo", ) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index d9dbd18a541..3938af68c7b 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -371,19 +371,20 @@ class FallingThingsStereo(StereoMatchingDataset): root FallingThings single - scene1 - _object_settings.json - _camera_settings.json - image1.left.depth.png - image1.right.depth.png - image1.left.jpg - image1.right.jpg - image2.left.depth.png - image2.right.depth.png - image2.left.jpg - image2.right - ... - scene2 + dir1 + scene1 + _object_settings.json + _camera_settings.json + image1.left.depth.png + image1.right.depth.png + image1.left.jpg + image1.right.jpg + image2.left.depth.png + image2.right.depth.png + image2.left.jpg + image2.right + ... + scene2 ... mixed scene1 @@ -420,13 +421,18 @@ def __init__(self, root: str, variant: str = "single", transforms: Optional[Call "both": ["single", "mixed"], }[variant] + split_prefix = { + "single": Path("*") / "*", + "mixed": Path("*"), + } + for s in variants: - left_img_pattern = str(root / s / "*" / "*.left.jpg") - right_img_pattern = str(root / s / "*" / "*.right.jpg") + left_img_pattern = str(root / s / split_prefix[s] / "*.left.jpg") + right_img_pattern = str(root / s / split_prefix[s] / "*.right.jpg") self._images += self._scan_pairs(left_img_pattern, right_img_pattern) - left_disparity_pattern = str(root / s / "*" / "*.left.depth.png") - right_disparity_pattern = str(root / s / "*" / "*.right.depth.png") + left_disparity_pattern = str(root / s / split_prefix[s] / "*.left.depth.png") + right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png") self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) def _read_disparity(self, file_path: str) -> Tuple: @@ -762,3 +768,103 @@ def __getitem__(self, index: int) -> Tuple: a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. """ return super().__getitem__(index) + + +class ETH3DStereo(StereoMatchingDataset): + """ETH3D `Low-Res Two-View `_ dataset. + + The dataset is expected to have the following structure: :: + + root + ETH3D + two_view_training + scene1 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + scene2 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + ... + two_view_training_gt + scene1 + disp0GT.pfm + mask0nocc.png + scene2 + disp0GT.pfm + mask0nocc.png + ... + two_view_testing + scene1 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + scene2 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + ... + + Args: + root (string): Root directory of the ETH3D Dataset. + split (string, optional): The dataset split of scenes, either "train" (default) or "test". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + _has_built_in_disparity_mask = True + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "ETH3D" + + img_dir = "two_view_training" if split == "train" else "two_view_test" + anot_dir = "two_view_training_gt" + + left_img_pattern = str(root / img_dir / "*" / "im0.png") + right_img_pattern = str(root / img_dir / "*" / "im1.png") + self._images = self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "test": + self._disparities = list((None, None) for _ in self._images) + else: + disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm") + self._disparities = self._scan_pairs(disparity_pattern, None) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has no disparity maps + if file_path is None: + return None, None + + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + mask_path = Path(file_path).parent / "mask0nocc.png" + valid_mask = Image.open(mask_path) + valid_mask = np.asarray(valid_mask).astype(bool) + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return super().__getitem__(index)