diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index bc0864083e0..af7ac072e31 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -101,6 +101,17 @@ Optical Flow KittiFlow Sintel +Stereo Matching +~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: generated/ + :template: class_dataset.rst + + CarlaStereo + Kitti2012Stereo + Kitti2015Stereo + Image pairs ~~~~~~~~~~~ diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 2043caae0a2..c232e7132b4 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -16,6 +16,8 @@ from collections import defaultdict from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union +import numpy as np + import PIL import PIL.Image import pytest @@ -23,6 +25,7 @@ import torchvision.datasets import torchvision.io from common_utils import disable_console_output, get_tmp_dir +from torchvision.transforms.functional import get_dimensions __all__ = [ @@ -748,6 +751,33 @@ def size(idx: int) -> Tuple[int, int, int]: ] +def shape_test_for_stereo( + left: PIL.Image.Image, + right: PIL.Image.Image, + disparity: Optional[np.ndarray] = None, + valid_mask: Optional[np.ndarray] = None, +): + left_dims = get_dimensions(left) + right_dims = get_dimensions(right) + c, h, w = left_dims + # check that left and right are the same size + assert left_dims == right_dims + assert c == 3 + + # check that the disparity has the same spatial dimensions + # as the input + if disparity is not None: + assert disparity.ndim == 3 + assert disparity.shape == (1, h, w) + + if valid_mask is not None: + # check that valid mask is the same size as the disparity + _, dh, dw = disparity.shape + mh, mw = valid_mask.shape + assert dh == mh + assert dw == mw + + @requires_lazy_imports("av") def create_video_file( root: Union[pathlib.Path, str], diff --git a/test/test_datasets.py b/test/test_datasets.py index a108479aee3..54696b0d6a8 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ 
-13,6 +13,7 @@ import unittest import xml.etree.ElementTree as ET import zipfile +from typing import Union import datasets_utils import numpy as np @@ -2671,5 +2672,174 @@ def inject_fake_data(self, tmpdir: str, config): return len(sampled_classes) * num_images_per_class[config["split"]] +class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Kitti2012Stereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + kitti_dir = pathlib.Path(tmpdir) / "Kitti2012" + os.makedirs(kitti_dir, exist_ok=True) + + split_dir = kitti_dir / (config["split"] + "ing") + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 4, "test": 3}.get(config["split"], 0) + + datasets_utils.create_image_folder( + root=split_dir, + name="colored_0", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + datasets_utils.create_image_folder( + root=split_dir, + name="colored_1", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + + if config["split"] == "train": + datasets_utils.create_image_folder( + root=split_dir, + name="disp_noc", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2012 uses a single channel image for disparities + size=(1, 100, 200), + ) + + return num_examples + + def test_train_splits(self): + for split in ["train"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo(left, right, disparity) + + def test_test_split(self): + for split in ["test"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + assert disparity is None + 
datasets_utils.shape_test_for_stereo(left, right) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Kitti2015Stereo + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + kitti_dir = pathlib.Path(tmpdir) / "Kitti2015" + os.makedirs(kitti_dir, exist_ok=True) + + split_dir = kitti_dir / (config["split"] + "ing") + os.makedirs(split_dir, exist_ok=True) + + num_examples = {"train": 4, "test": 6}.get(config["split"], 0) + + datasets_utils.create_image_folder( + root=split_dir, + name="image_2", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + datasets_utils.create_image_folder( + root=split_dir, + name="image_3", + file_name_fn=lambda i: f"{i:06d}_10.png", + num_examples=num_examples, + size=(3, 100, 200), + ) + + if config["split"] == "train": + datasets_utils.create_image_folder( + root=split_dir, + name="disp_occ_0", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2015 uses a single channel image for disparities + size=(1, 100, 200), + ) + + datasets_utils.create_image_folder( + root=split_dir, + name="disp_occ_1", + file_name_fn=lambda i: f"{i:06d}.png", + num_examples=num_examples, + # Kitti2015 uses a single channel image for disparities + size=(1, 100, 200), + ) + + return num_examples + + def test_train_splits(self): + for split in ["train"]: + with self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + datasets_utils.shape_test_for_stereo(left, right, disparity) + + def test_test_split(self): + for split in ["test"]: + with 
self.create_dataset(split=split) as (dataset, _): + for left, right, disparity, mask in dataset: + assert mask is None + assert disparity is None + datasets_utils.shape_test_for_stereo(left, right) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + +class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.CarlaStereo + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) + + @staticmethod + def _create_scene_folders(num_examples: int, root_dir: Union[str, pathlib.Path]): + # make the root_dir if it does not exist + os.makedirs(root_dir, exist_ok=True) + + for i in range(num_examples): + scene_dir = pathlib.Path(root_dir) / f"scene_{i}" + os.makedirs(scene_dir, exist_ok=True) + # populate with left right images + datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100)) + datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100)) + datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp0GT.pfm")) + datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp1GT.pfm")) + + def inject_fake_data(self, tmpdir, config): + carla_dir = pathlib.Path(tmpdir) / "carla-highres" + os.makedirs(carla_dir, exist_ok=True) + + split_dir = pathlib.Path(carla_dir) / "trainingF" + os.makedirs(split_dir, exist_ok=True) + + num_examples = 6 + self._create_scene_folders(num_examples=num_examples, root_dir=split_dir) + + return num_examples + + def test_train_splits(self): + with self.create_dataset() as (dataset, _): + for left, right, disparity in dataset: + datasets_utils.shape_test_for_stereo(left, right, disparity) + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 099d10da35d..d8b6293fb42 100644 --- a/torchvision/datasets/__init__.py +++
b/torchvision/datasets/__init__.py @@ -1,4 +1,5 @@ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel +from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -105,4 +106,7 @@ "FGVCAircraft", "EuroSAT", "RenderedSST2", + "Kitti2012Stereo", + "Kitti2015Stereo", + "CarlaStereo", ) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py new file mode 100644 index 00000000000..de213fc0368 --- /dev/null +++ b/torchvision/datasets/_stereo_matching.py @@ -0,0 +1,361 @@ +import functools +from abc import ABC, abstractmethod +from glob import glob +from pathlib import Path +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +from PIL import Image + +from .utils import _read_pfm, verify_str_arg +from .vision import VisionDataset + +__all__ = () + +_read_pfm_file = functools.partial(_read_pfm, slice_channels=1) + + +class StereoMatchingDataset(ABC, VisionDataset): + """Base interface for Stereo matching datasets""" + + _has_built_in_disparity_mask = False + + def __init__(self, root: str, transforms: Optional[Callable] = None): + """ + Args: + root(str): Root directory of the dataset. + transforms(callable, optional): A function/transform that takes in Tuples of + (images, disparities, valid_masks) and returns a transformed version of each of them. + images is a Tuple of (``PIL.Image``, ``PIL.Image``) + disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W) + valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W) + In some cases, when a dataset does not provide disparities, the ``disparities`` and + ``valid_masks`` can be Tuples containing None values. 
+ For training splits generally the datasets provide a minimal guarantee of + images: (``PIL.Image``, ``PIL.Image``) + disparities: (``np.ndarray``, ``None``) with shape (1, H, W) + Optionally, based on the dataset, it can return a ``mask`` as well: + valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W) + For some test splits, the datasets provide outputs that look like: + images: (``PIL.Image``, ``PIL.Image``) + disparities: (``None``, ``None``) + Optionally, based on the dataset, it can return a ``mask`` as well: + valid_masks: (``None``, ``None``) + """ + super().__init__(root=root) + self.transforms = transforms + + self._images = [] # type: ignore + self._disparities = [] # type: ignore + + def _read_img(self, file_path: str) -> Image.Image: + img = Image.open(file_path) + if img.mode != "RGB": + img = img.convert("RGB") + return img + + def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None): + + left_paths = list(sorted(glob(paths_left_pattern))) + + right_paths: List[Union[None, str]] + if paths_right_pattern: + right_paths = list(sorted(glob(paths_right_pattern))) + else: + right_paths = list(None for _ in left_paths) + + if not left_paths: + raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}") + + if not right_paths: + raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}") + + if len(left_paths) != len(right_paths): + raise ValueError( + f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n " + f"left pattern: {paths_left_pattern}\n" + f"right pattern: {paths_right_pattern}\n" + ) + + paths = list((left, right) for left, right in zip(left_paths, right_paths)) + return paths + + @abstractmethod + def _read_disparity(self, file_path: str) -> Tuple: + # function that returns a (disparity map, valid mask) pair; either may be None + pass + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index.
+ + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask`` + can be a numpy boolean mask of shape (H, W) if the dataset provides a file + indicating which disparity pixels are valid. The disparity is a numpy array of + shape (1, H, W) and the images are PIL images. ``disparity`` is None for + datasets on which for ``split="test"`` the authors did not provide annotations. + """ + img_left = self._read_img(self._images[index][0]) + img_right = self._read_img(self._images[index][1]) + + dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0]) + dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1]) + + imgs = (img_left, img_right) + dsp_maps = (dsp_map_left, dsp_map_right) + valid_masks = (valid_mask_left, valid_mask_right) + + if self.transforms is not None: + ( + imgs, + dsp_maps, + valid_masks, + ) = self.transforms(imgs, dsp_maps, valid_masks) + + if self._has_built_in_disparity_mask or valid_masks[0] is not None: + return imgs[0], imgs[1], dsp_maps[0], valid_masks[0] + else: + return imgs[0], imgs[1], dsp_maps[0] + + def __len__(self) -> int: + return len(self._images) + + +class CarlaStereo(StereoMatchingDataset): + """ + Carla simulator data linked in the `CREStereo github repo `_. + + The dataset is expected to have the following structure: :: + + root + carla-highres + trainingF + scene1 + im0.png + im1.png + disp0GT.pfm + disp1GT.pfm + calib.txt + scene2 + im0.png + im1.png + disp0GT.pfm + disp1GT.pfm + calib.txt + ... + + Args: + root (string): Root directory where `carla-highres` is located. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+ """ + + def __init__(self, root: str, transforms: Optional[Callable] = None): + super().__init__(root, transforms) + + root = Path(root) / "carla-highres" + + left_image_pattern = str(root / "trainingF" / "*" / "im0.png") + right_image_pattern = str(root / "trainingF" / "*" / "im1.png") + imgs = self._scan_pairs(left_image_pattern, right_image_pattern) + self._images = imgs + + left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm") + right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm") + disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + self._disparities = disparities + + def _read_disparity(self, file_path: str) -> Tuple: + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return super().__getitem__(index) + + +class Kitti2012Stereo(StereoMatchingDataset): + """ + KITTI dataset from the `2012 stereo evaluation benchmark `_. + Uses the RGB images for consistency with KITTI 2015. + + The dataset is expected to have the following structure: :: + + root + Kitti2012 + testing + colored_0 + 1_10.png + 2_10.png + ... + colored_1 + 1_10.png + 2_10.png + ... + training + colored_0 + 1_10.png + 2_10.png + ... + colored_1 + 1_10.png + 2_10.png + ... + disp_noc + 1.png + 2.png + ... + calib + + Args: + root (string): Root directory where `Kitti2012` is located. 
+ split (string, optional): The dataset split of scenes, either "train" (default) or "test". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + _has_built_in_disparity_mask = True + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "Kitti2012" / (split + "ing") + + left_img_pattern = str(root / "colored_0" / "*_10.png") + right_img_pattern = str(root / "colored_1" / "*_10.png") + self._images = self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "train": + disparity_pattern = str(root / "disp_noc" / "*.png") + self._disparities = self._scan_pairs(disparity_pattern, None) + else: + self._disparities = list((None, None) for _ in self._images) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has no disparity maps + if file_path is None: + return None, None + + disparity_map = np.asarray(Image.open(file_path)) / 256.0 + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return super().__getitem__(index) + + +class Kitti2015Stereo(StereoMatchingDataset): + """ + KITTI dataset from the `2015 stereo evaluation benchmark `_. 
+ + The dataset is expected to have the following structure: :: + + root + Kitti2015 + testing + image_2 + img1.png + img2.png + ... + image_3 + img1.png + img2.png + ... + training + image_2 + img1.png + img2.png + ... + image_3 + img1.png + img2.png + ... + disp_occ_0 + img1.png + img2.png + ... + disp_occ_1 + img1.png + img2.png + ... + calib + + Args: + root (string): Root directory where `Kitti2015` is located. + split (string, optional): The dataset split of scenes, either "train" (default) or "test". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + _has_built_in_disparity_mask = True + + def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): + super().__init__(root, transforms) + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "Kitti2015" / (split + "ing") + left_img_pattern = str(root / "image_2" / "*.png") + right_img_pattern = str(root / "image_3" / "*.png") + self._images = self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "train": + left_disparity_pattern = str(root / "disp_occ_0" / "*.png") + right_disparity_pattern = str(root / "disp_occ_1" / "*.png") + self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + else: + self._disparities = list((None, None) for _ in self._images) + + def _read_disparity(self, file_path: str) -> Tuple: + # test split has no disparity maps + if file_path is None: + return None, None + + disparity_map = np.asarray(Image.open(file_path)) / 256.0 + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> Tuple: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. 
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return super().__getitem__(index)