# Patch summary: port the prototype CLEVR dataset to the Dataset2 API.
#   * test/builtin_dataset_mocks.py: re-enable the `clevr` mock, register it for the
#     "train"/"val"/"test" splits, and read the split via config["split"].
#   * torchvision/prototype/datasets/_builtin/clevr.py: register info and dataset under
#     NAME, take `split` as a constructor argument (stored as self._split), rename
#     resources/_make_datapipe to _resources/_datapipe, and add __len__.
# NOTE(review): line breaks and indentation were reconstructed from a
# whitespace-collapsed copy of this patch -- verify against the upstream commit.
diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index ad979b6bd84..586841ba3b2 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -1113,8 +1113,8 @@ def _make_ann_file(path, num_examples, class_idx):
     return num_examples
 
 
-# @register_mock
-def clevr(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def clevr(root, config):
     data_folder = root / "CLEVR_v1.0"
 
     num_samples_map = {
@@ -1155,7 +1155,7 @@ def clevr(info, root, config):
 
     make_zip(root, f"{data_folder.name}.zip", data_folder)
 
-    return num_samples_map[config.split]
+    return num_samples_map[config["split"]]
 
 
 class OxfordIIITPetMockData:
diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py
index dd08a257a5b..9d322de084c 100644
--- a/torchvision/prototype/datasets/_builtin/clevr.py
+++ b/torchvision/prototype/datasets/_builtin/clevr.py
@@ -1,14 +1,8 @@
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, BinaryIO
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union
 
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
-    HttpResource,
-    OnlineResource,
-)
+from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     hint_sharding,
@@ -19,16 +13,30 @@
 )
 from torchvision.prototype.features import Label, EncodedImage
 
+from .._api import register_dataset, register_info
+
+NAME = "clevr"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict()
 
-class CLEVR(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "clevr",
-            homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
-            valid_options=dict(split=("train", "val", "test")),
-        )
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+@register_dataset(NAME)
+class CLEVR(Dataset2):
+    """
+    - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
+
+    def _resources(self) -> List[OnlineResource]:
         archive = HttpResource(
             "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
             sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
@@ -61,12 +69,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, A
             label=Label(len(scenes_data["objects"])) if scenes_data else None,
         )
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, scenes_dp = Demultiplexer(
             archive_dp,
@@ -76,12 +79,12 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
         )
 
-        images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
+        images_dp = Filter(images_dp, path_comparator("parent.name", self._split))
         images_dp = hint_shuffling(images_dp)
         images_dp = hint_sharding(images_dp)
 
-        if config.split != "test":
-            scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
+        if self._split != "test":
+            scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json"))
             scenes_dp = JsonParser(scenes_dp)
             scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
             scenes_dp = UnBatcher(scenes_dp)
@@ -97,3 +100,6 @@ def _make_datapipe(
             dp = Mapper(images_dp, self._add_empty_anns)
 
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 70_000 if self._split == "train" else 15_000