diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index 20606424319..d3111d49730 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -695,9 +695,9 @@ def generate(cls, root):
         return num_samples_map


-# @register_mock
-def sbd(info, root, config):
-    return SBDMockData.generate(root)[config.split]
+@register_mock(configs=combinations_grid(split=("train", "val", "train_noval")))
+def sbd(root, config):
+    return SBDMockData.generate(root)[config["split"]]


 @register_mock(configs=[dict()])
diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py
index bcacaea2d24..d062d78fe0a 100644
--- a/torchvision/prototype/datasets/_builtin/sbd.py
+++ b/torchvision/prototype/datasets/_builtin/sbd.py
@@ -1,6 +1,6 @@
 import pathlib
 import re
-from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union

 import numpy as np
 from torchdata.datapipes.iter import (
@@ -11,13 +11,7 @@
     IterKeyZipper,
     LineReader,
 )
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
-    HttpResource,
-    OnlineResource,
-)
+from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     read_mat,
@@ -26,22 +20,44 @@
     path_comparator,
     hint_sharding,
     hint_shuffling,
+    BUILTIN_DIR,
 )
 from torchvision.prototype.features import _Feature, EncodedImage

+from .._api import register_dataset, register_info
+
+NAME = "sbd"
+
+CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))

-class SBD(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "sbd",
-            dependencies=("scipy",),
-            homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
-            valid_options=dict(
-                split=("train", "val", "train_noval"),
-            ),
-        )

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=CATEGORIES)
+
+
+@register_dataset(NAME)
+class SBD(Dataset2):
+    """
+    - **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html
+    - **dependencies**:
+        - scipy
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval"))
+
+        self._categories = CATEGORIES
+
+        super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
+
+    def _resources(self) -> List[OnlineResource]:
         archive = HttpResource(
             "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
             sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
@@ -85,12 +101,7 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
             segmentation=_Feature(anns["Segmentation"].item()),
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp, extra_split_dp = resource_dps

         archive_dp = resource_dps[0]
@@ -101,10 +112,10 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
             drop_none=True,
         )
-        if config.split == "train_noval":
+        if self._split == "train_noval":
             split_dp = extra_split_dp

-        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+        split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_shuffling(split_dp)
         split_dp = hint_sharding(split_dp)
@@ -120,10 +131,17 @@ def _make_datapipe(
         )
         return Mapper(dp, self._prepare_sample)

-    def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
-        resources = self._resources(self.default_config)
+    def __len__(self) -> int:
+        return {
+            "train": 8_498,
+            "val": 2_857,
+            "train_noval": 5_623,
+        }[self._split]
+
+    def _generate_categories(self) -> Tuple[str, ...]:
+        resources = self._resources()

-        dp = resources[0].load(root)
+        dp = resources[0].load(self._root)
         dp = Filter(dp, path_comparator("name", "category_names.m"))
         dp = LineReader(dp)
         dp = Mapper(dp, bytes.decode, input_col=1)
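For reviewers who want to smoke-test the port, here is a minimal usage sketch of the constructor-based API introduced above. The `root` path is hypothetical, and iterating assumes the SBD archive is already under `root` (the `HttpResource` will otherwise download it); `skip_integrity_check` only skips the checksum, not the files themselves.

```python
import pathlib

# Direct construction of the ported dataset; `SBD` and its keyword-only
# arguments are exactly those added in the diff above.
from torchvision.prototype.datasets._builtin.sbd import SBD

root = pathlib.Path("~/datasets/sbd").expanduser()  # hypothetical location
dataset = SBD(root, split="train_noval", skip_integrity_check=True)

# __len__ now reports the hard-coded split sizes instead of deriving them
# from a DatasetConfig.
assert len(dataset) == 5_623

# Dataset2 is an IterDataPipe, so samples stream lazily on iteration; the
# dict keys come from _prepare_sample (e.g. "segmentation").
sample = next(iter(dataset))
print(sorted(sample))
```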