Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Stereo Matching
CarlaStereo
Kitti2012Stereo
Kitti2015Stereo
CREStereo
FallingThingsStereo
SceneFlowStereo
SintelStereo
Expand Down
31 changes: 31 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2841,6 +2841,37 @@ def test_train_splits(self):
datasets_utils.shape_test_for_stereo(left, right, disparity)


class CREStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Exercise ``datasets.CREStereo`` against a small synthetic directory tree."""

    DATASET_CLASS = datasets.CREStereo
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, np.ndarray, type(None))

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir/CREStereo/<category>`` with fake stereo samples.

        Each sample consists of a left/right ``.jpg`` image pair and a
        left/right ``.disp.png`` disparity pair.

        Returns:
            int: total number of samples created across all categories.
        """
        dataset_root = pathlib.Path(tmpdir) / "CREStereo"
        os.makedirs(dataset_root, exist_ok=True)

        # (category directory, number of samples), in creation order
        samples_per_category = (
            ("shapenet", 3),
            ("reflective", 6),
            ("tree", 2),
            ("hole", 5),
        )

        # (file name suffix, image size), in creation order;
        # the single-channel sizes are going to end up being gray scale images
        file_specs = (
            ("left.jpg", (100, 100)),
            ("right.jpg", (100, 100)),
            ("left.disp.png", (1, 100, 100)),
            ("right.disp.png", (1, 100, 100)),
        )

        total = 0
        for category, count in samples_per_category:
            category_dir = dataset_root / category
            os.makedirs(category_dir, exist_ok=True)

            for sample_idx in range(count):
                for suffix, size in file_specs:
                    datasets_utils.create_image_file(
                        root=category_dir,
                        name=f"{sample_idx}_{suffix}",
                        size=size,
                    )

            total += count

        return total

    def test_splits(self):
        """Every sample must be a 4-tuple with a ``None`` mask and stereo-consistent shapes."""
        with self.create_dataset() as (dataset, _):
            for sample in dataset:
                left, right, disparity, mask = sample
                assert mask is None
                datasets_utils.shape_test_for_stereo(left, right, disparity)


class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FallingThingsStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(variant=("single", "mixed", "both"))
Expand Down
2 changes: 2 additions & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import (
CarlaStereo,
CREStereo,
ETH3DStereo,
FallingThingsStereo,
InStereo2k,
Expand Down Expand Up @@ -118,6 +119,7 @@
"Kitti2012Stereo",
"Kitti2015Stereo",
"CarlaStereo",
"CREStereo",
"FallingThingsStereo",
"SceneFlowStereo",
"SintelStereo",
Expand Down
88 changes: 88 additions & 0 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,94 @@ def __getitem__(self, index: int) -> Tuple:
return super().__getitem__(index)


class CREStereo(StereoMatchingDataset):
    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.

    The dataset is expected to have the following structure: ::

        root
            CREStereo
                tree
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    img2_left.jpg
                    img2_right.jpg
                    img2_left.disp.png
                    img2_right.disp.png
                    ...
                shapenet
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                reflective
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                hole
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...

    Args:
        root (str): Root directory of the dataset.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # NOTE(review): presumably makes the base class always return the 4-tuple
    # form (with ``valid_mask`` possibly ``None``) -- confirm against
    # ``StereoMatchingDataset.__getitem__``.
    _has_built_in_disparity_mask = True

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "CREStereo"

        # one sub-directory per scene category of the synthetic dataset
        dirs = ["shapenet", "reflective", "tree", "hole"]

        for s in dirs:
            # RGB stereo pairs are stored as JPEG files
            left_image_pattern = str(root / s / "*_left.jpg")
            right_image_pattern = str(root / s / "*_right.jpg")
            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
            self._images += imgs

            # disparity maps are stored as PNG files next to the images
            left_disparity_pattern = str(root / s / "*_left.disp.png")
            right_disparity_pattern = str(root / s / "*_right.disp.png")
            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
            self._disparities += disparities

    def _read_disparity(self, file_path: str) -> Tuple:
        """Load a disparity PNG and decode it to float disparities.

        Args:
            file_path (str): Path to a ``*.disp.png`` file.

        Returns:
            tuple: ``(disparity_map, valid_mask)`` where ``disparity_map`` is a
            ``float32`` numpy array with a leading channel axis and
            ``valid_mask`` is always ``None``.
        """
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        # unsqueeze the disparity map into (C, H, W) format;
        # the official CREStereo loader decodes disparity as ``raw_value / 32``
        disparity_map = disparity_map[None, :, :] / 32.0
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
        """
        return super().__getitem__(index)


class FallingThingsStereo(StereoMatchingDataset):
"""`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.

Expand Down