diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index eef7275f967..cc8568154ed 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -1430,13 +1430,13 @@ def svhn(info, root, config):
     return num_samples
 
 
-# @register_mock
-def pcam(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "val", "test")))
+def pcam(root, config):
     import h5py
 
-    num_images = {"train": 2, "test": 3, "val": 4}[config.split]
+    num_images = {"train": 2, "test": 3, "val": 4}[config["split"]]
 
-    split = "valid" if config.split == "val" else config.split
+    split = "valid" if config["split"] == "val" else config["split"]
 
     images_io = io.BytesIO()
     with h5py.File(images_io, "w") as f:
diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py
index 3d7b9547a76..1ae94da5665 100644
--- a/torchvision/prototype/datasets/_builtin/pcam.py
+++ b/torchvision/prototype/datasets/_builtin/pcam.py
@@ -1,13 +1,12 @@
 import io
+import pathlib
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Tuple, Iterator
+from typing import Any, Dict, List, Optional, Tuple, Iterator, Union
 
 from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
 from torchvision.prototype import features
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
+    Dataset2,
     OnlineResource,
     GDriveResource,
 )
@@ -17,6 +16,11 @@
 )
 from torchvision.prototype.features import Label
 
+from .._api import register_dataset, register_info
+
+
+NAME = "pcam"
+
 
 class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
     def __init__(
@@ -40,15 +44,25 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
 _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
 
 
-class PCAM(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "pcam",
-            homepage="https://github.com/basveeling/pcam",
-            categories=2,
-            valid_options=dict(split=("train", "test", "val")),
-            dependencies=["h5py"],
-        )
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=["0", "1"])
+
+
+@register_dataset(NAME)
+class PCAM(Dataset2):
+    # TODO write proper docstring
+    """PCAM Dataset
+
+    homepage="https://github.com/basveeling/pcam"
+    """
+
+    def __init__(
+        self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",))
 
     _RESOURCES = {
         "train": (
@@ -89,10 +103,10 @@ def _make_info(self) -> DatasetInfo:
         ),
     }
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         return [  # = [images resource, targets resource]
             GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
-            for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
+            for file_name, gdrive_id, sha256 in self._RESOURCES[self._split]
         ]
 
     def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
@@ -100,12 +114,10 @@ def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
 
         return {
             "image": features.Image(image.transpose(2, 0, 1)),
-            "label": Label(target.item()),
+            "label": Label(target.item(), categories=self._categories),
         }
 
-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
 
         images_dp, targets_dp = resource_dps
 
@@ -116,3 +128,6 @@ def _make_datapipe(
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return 262_144 if self._split == "train" else 32_768