diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 8d27240c75d..d8dd9a5c34d 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -10,7 +10,7 @@ import numpy as np import pytest import torch -from datasets_utils import create_image_folder, make_tar, make_zip +from datasets_utils import create_image_folder, make_tar, make_zip, make_fake_flo_file from torch.testing import make_tensor as _make_tensor from torchdata.datapipes.iter import IterDataPipe from torchvision.prototype import datasets @@ -490,3 +490,42 @@ def imagenet(info, root, config): make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz") return num_samples + + +@dataset_mocks.register_mock_data_fn +def sintel(info, root, config): + FLOW_H, FLOW_W = 3, 4 + + num_images_per_scene = 3 if config["split"] == "train" else 4 + num_scenes = 2 + + for split_dir in ("training", "test"): + for pass_name in ("clean", "final"): + image_root = root / split_dir / pass_name + + for scene_id in range(num_scenes): + scene_dir = image_root / f"scene_{scene_id}" + create_image_folder( + image_root, + name=str(scene_dir), + file_name_fn=lambda image_idx: f"frame_000{image_idx}.png", + num_examples=num_images_per_scene, + ) + + flow_root = root / "training" / "flow" + for scene_id in range(num_scenes): + scene_dir = flow_root / f"scene_{scene_id}" + scene_dir.mkdir(exist_ok=True, parents=True) + for i in range(num_images_per_scene - 1): + file_name = str(scene_dir / f"frame_000{i}.flo") + make_fake_flo_file(h=FLOW_H, w=FLOW_W, file_name=file_name) + + # with e.g. num_images_per_scene = 3, for a single scene with have 3 images + # which are frame_0000, frame_0001 and frame_0002 + # They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002), + # that is 3 - 1 = 2 examples. Hence the formula below + num_passes = 2 if config["pass_name"] == "both" else 1 + num_examples = (num_images_per_scene - 1) * num_scenes * num_passes + + make_zip(root, "MPI-Sintel-complete.zip", "training", "test") + return num_examples diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index caf9105006c..547d6178e33 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -30,6 +30,7 @@ def dataset_parametrization(*names, decoder=to_bytes): "caltech256", "caltech101", "imagenet", + "sintel", ) params = [] diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 62abc3119f6..589fecf0323 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -6,4 +6,5 @@ from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST from .sbd import SBD from .semeion import SEMEION +from .sintel import SINTEL from .voc import VOC diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 54a31edfa5c..d2a8a7bc67c 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -95,7 +95,7 @@ def _make_datapipe( def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) dp = TarArchiveReader(dp) - dp = Filter(dp, path_comparator("name", self._META_FILE_NAME)) + dp = Filter(dp, path_comparator("name", value=self._META_FILE_NAME)) dp = Mapper(dp, self._unpickle) return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 641d584dc43..749793f1741 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -115,7 +115,7 @@ def _make_datapipe( images_dp, meta_dp = resource_dps meta_dp = ZipArchiveReader(meta_dp) - meta_dp = Filter(meta_dp, path_comparator("name", f"instances_{config.split}{config.year}.json")) + meta_dp = Filter(meta_dp, path_comparator("name", value=f"instances_{config.split}{config.year}.json")) meta_dp = JsonParser(meta_dp) meta_dp = Mapper(meta_dp, getitem(1)) meta_dp = MappingIterator(meta_dp) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index d25d55c216f..5639dd59656 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -147,7 +147,7 @@ def _make_datapipe( dp = Mapper(dp, self._collate_train_data) elif config.split == "val": devkit_dp = TarArchiveReader(devkit_dp) - devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt")) + devkit_dp = Filter(devkit_dp, path_comparator("name", value="ILSVRC2012_validation_ground_truth.txt")) devkit_dp = LineReader(devkit_dp, return_path=False) devkit_dp = Mapper(devkit_dp, int) devkit_dp = Enumerator(devkit_dp, 1) @@ -178,7 +178,7 @@ def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]: resources = self.resources(self.default_config) devkit_dp = resources[1].to_datapipe(root / self.name) devkit_dp = TarArchiveReader(devkit_dp) - devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) + devkit_dp = Filter(devkit_dp, path_comparator("name", value="meta.mat")) meta = next(iter(devkit_dp))[1] synsets = read_mat(meta, squeeze_me=True)["synsets"] diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 888a464f69b..046c9bc5d46 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -156,7 +156,7 @@ def _make_datapipe( def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name) dp = TarArchiveReader(dp) - dp = Filter(dp, path_comparator("name", "category_names.m")) + dp = Filter(dp, path_comparator("name", value="category_names.m")) dp = LineReader(dp) dp = Mapper(dp, bytes.decode, input_col=1) lines = tuple(zip(*iter(dp)))[1] diff --git a/torchvision/prototype/datasets/_builtin/sintel.py b/torchvision/prototype/datasets/_builtin/sintel.py new file mode 100644 index 00000000000..15c517e6ab9 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/sintel.py @@ -0,0 +1,146 @@ +import io +import pathlib +import re +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, BinaryIO + +import torch +from torchdata.datapipes.iter import ( + IterDataPipe, + Demultiplexer, + Mapper, + Shuffler, + Filter, + IterKeyZipper, + ZipArchiveReader, +) +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, + DatasetType, +) +from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_flo, InScenePairer, path_accessor + + +class SINTEL(Dataset): + + _FILE_NAME_PATTERN = re.compile(r"(frame|image)_(?P\d+)[.](flo|png)") + + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "sintel", + type=DatasetType.IMAGE, + homepage="http://sintel.is.tue.mpg.de/", + valid_options=dict( + split=("train", "test"), + pass_name=("clean", "final", "both"), + ), + ) + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + archive = HttpResource( + "http://files.is.tue.mpg.de/sintel/MPI-Sintel-complete.zip", + sha256="bdc80abbe6ae13f96f6aa02e04d98a251c017c025408066a00204cd2c7104c5f", + ) + return [archive] + + def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool: + path = pathlib.Path(data[0]) + # The dataset contains has the folder "training", while allowed options for `split` are + # "train" and "test", we don't check for equality here ("train" != "training") and instead + # check if split is in the folder name + return split in path.parents[2].name + + def _filter_pass_name_and_flow(self, data: Tuple[str, Any], *, pass_name: str) -> bool: + path = pathlib.Path(data[0]) + if pass_name == "both": + matched = path.parents[1].name in ["clean", "final", "flow"] + else: + matched = path.parents[1].name in [pass_name, "flow"] + return matched + + def _classify_archive(self, data: Tuple[str, Any], *, pass_name: str) -> Optional[int]: + path = pathlib.Path(data[0]) + suffix = path.suffix + if suffix == ".flo": + return 0 + elif suffix == ".png": + return 1 + else: + return None + + def _flows_key(self, data: Tuple[str, Any]) -> Tuple[str, int]: + path = pathlib.Path(data[0]) + category = path.parent.name + idx = int(self._FILE_NAME_PATTERN.match(path.name).group("idx")) # type: ignore[union-attr] + return category, idx + + def _add_fake_flow_data(self, data: Tuple[str, Any]) -> Tuple[Tuple[None, None], Tuple[str, Any]]: + return ((None, None), data) + + def _images_key(self, data: Tuple[Tuple[str, Any], Tuple[str, Any]]) -> Tuple[str, int]: + return self._flows_key(data[0]) + + def _collate_and_decode_sample( + self, + data: Tuple[Tuple[Optional[str], Optional[BinaryIO]], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]], + *, + decoder: Optional[Callable[[BinaryIO], torch.Tensor]], + ) -> Dict[str, Any]: + flow_data, images_data = data + flow_path, flow_buffer = flow_data + image1_data, image2_data = images_data + image1_path, image1_buffer = image1_data + image2_path, image2_buffer = image2_data + + return dict( + image1=decoder(image1_buffer) if decoder else image1_buffer, + image1_path=image1_path, + image2=decoder(image2_buffer) if decoder else image2_buffer, + image2_path=image2_path, + flow=read_flo(flow_buffer) if flow_buffer else None, + flow_path=flow_path, + scene=pathlib.Path(image1_path).parent.name, + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + ) -> IterDataPipe[Dict[str, Any]]: + dp = resource_dps[0] + archive_dp = ZipArchiveReader(dp) + + curr_split = Filter(archive_dp, self._filter_split, fn_kwargs=dict(split=config.split)) + filtered_curr_split = Filter( + curr_split, self._filter_pass_name_and_flow, fn_kwargs=dict(pass_name=config.pass_name) + ) + if config.split == "train": + flo_dp, pass_images_dp = Demultiplexer( + filtered_curr_split, + 2, + partial(self._classify_archive, pass_name=config.pass_name), + drop_none=True, + buffer_size=INFINITE_BUFFER_SIZE, + ) + flo_dp = Shuffler(flo_dp, buffer_size=INFINITE_BUFFER_SIZE) + pass_images_dp: IterDataPipe[Tuple[str, Any], Tuple[str, Any]] = InScenePairer( + pass_images_dp, scene_fn=path_accessor("parent", "name") + ) + zipped_dp = IterKeyZipper( + flo_dp, + pass_images_dp, + key_fn=self._flows_key, + ref_key_fn=self._images_key, + ) + else: + pass_images_dp = Shuffler(filtered_curr_split, buffer_size=INFINITE_BUFFER_SIZE) + pass_images_dp = InScenePairer(pass_images_dp, scene_fn=path_accessor("parent", "name")) + zipped_dp = Mapper(pass_images_dp, self._add_fake_flow_data) + + return Mapper(zipped_dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index a4175765555..4995f9a9d99 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -129,7 +129,7 @@ def _make_datapipe( ) split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task])) - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, path_comparator("name", value=f"{config.split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE) diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 3db10183f68..c94de7f44c8 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -21,6 +21,7 @@ Optional, IO, Sized, + Iterable, ) from typing import cast @@ -33,6 +34,19 @@ from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader from torchdata.datapipes.utils import StreamWrapper +K = TypeVar("K") +D = TypeVar("D") + +try: + from itertools import pairwise # type: ignore[attr-defined] +except ImportError: + from itertools import tee + + def pairwise(iterable: Iterable[D]) -> Iterable[Tuple[D, D]]: + a, b = tee(iterable) + next(b, None) + return zip(a, b) + __all__ = [ "INFINITE_BUFFER_SIZE", @@ -48,11 +62,9 @@ "Decompressor", "fromfile", "read_flo", + "InScenePairer", ] -K = TypeVar("K") -D = TypeVar("D") - # pseudo-infinite until a true infinite buffer is supported by all datapipes INFINITE_BUFFER_SIZE = 1_000_000_000 @@ -117,17 +129,22 @@ def getitem(*items: Any) -> Callable[[Any], Any]: return functools.partial(_getitem_closure, items=items) -def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D: - return cast(D, getattr(path, name)) +def _path_attribute_accessor(path: pathlib.Path, *, attrs: Sequence[str]) -> D: + obj: Any = path + for attr in attrs: + obj = getattr(obj, attr) + return cast(D, obj) def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D: return getter(pathlib.Path(data[0])) -def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]: - if isinstance(getter, str): - getter = functools.partial(_path_attribute_accessor, name=getter) +def path_accessor(*attrs: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]: + if not callable(attrs[0]): + getter = cast(Callable[[pathlib.Path], D], functools.partial(_path_attribute_accessor, attrs=attrs)) + else: + getter = attrs[0] return functools.partial(_path_accessor_closure, getter=getter) @@ -136,8 +153,8 @@ def _path_comparator_closure(data: Tuple[str, Any], *, accessor: Callable[[Tuple return accessor(data) == value -def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]: - return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value) +def path_comparator(*attrs: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]: + return functools.partial(_path_comparator_closure, accessor=path_accessor(*attrs), value=value) class CompressionType(enum.Enum): @@ -321,3 +338,26 @@ def read_flo(file: BinaryIO) -> torch.Tensor: width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2) flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2) return flow.reshape((height, width, 2)).permute((2, 0, 1)) + + +class InScenePairer(IterDataPipe[Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]): + def __init__( + self, datapipe: IterDataPipe[Tuple[str, BinaryIO]], *, scene_fn: Callable[[Tuple[str, BinaryIO]], K] + ) -> None: + self.datapipe = datapipe + self.scene_fn = scene_fn + + def __iter__(self) -> Iterator[Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]: + prev_bytes = b"" + for (path1, stream1), (path2, stream2) in pairwise(sorted(self.datapipe)): + if self.scene_fn(path1) != self.scene_fn(path2): + prev_bytes = b"" + continue + + buffer1 = io.BytesIO(prev_bytes) if prev_bytes else stream1 + prev_bytes = stream2.read() + if prev_bytes == b"": + print() + buffer2 = io.BytesIO(prev_bytes) + + yield (path1, buffer1), (path2, buffer2)