diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 62259a604a0..1153c1b33f0 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1,3 +1,4 @@ +import bz2 import collections.abc import csv import functools @@ -1431,3 +1432,21 @@ def stanford_cars(info, root, config): make_tar(root, "car_devkit.tgz", devkit, compression="gz") return num_samples + + +@register_mock +def usps(info, root, config): + num_samples = {"train": 15, "test": 7}[config.split] + + with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh: + lines = [] + for _ in range(num_samples): + label = make_tensor(1, low=1, high=11, dtype=torch.int) + values = make_tensor(256, low=-1, high=1, dtype=torch.float) + lines.append( + " ".join([f"{int(label)}", *(f"{idx}:{float(value):.6f}" for idx, value in enumerate(values, 1))]) + ) + + fh.write("\n".join(lines).encode()) + + return num_samples diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index f7c40d432a4..f414f4e48cd 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -12,7 +12,7 @@ from torchdata.datapipes.iter import IterDataPipe, Shuffler from torchvision._utils import sequence_to_str from torchvision.prototype import transforms, datasets - +from torchvision.prototype.features import Image, Label assert_samples_equal = functools.partial( assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True @@ -180,3 +180,20 @@ def test_label_matches_path(self, test_home, dataset_mock, config): for sample in dataset: label_from_path = int(Path(sample["path"]).parent.name) assert sample["label"] == label_from_path + + +@parametrize_dataset_mocks(DATASET_MOCKS["usps"]) +class TestUSPS: + def test_sample_content(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) + + dataset = datasets.load(dataset_mock.name, **config) + + for sample in dataset: + assert "image" in sample + assert "label" in sample + + assert isinstance(sample["image"], Image) + assert isinstance(sample["label"], Label) + + assert sample["image"].shape == (1, 16, 16) diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index feb558aa03f..1a8dc0907a4 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -17,4 +17,5 @@ from .semeion import SEMEION from .stanford_cars import StanfordCars from .svhn import SVHN +from .usps import USPS from .voc import VOC diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py new file mode 100644 index 00000000000..5df0978d031 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -0,0 +1,54 @@ +from typing import Any, Dict, List + +import torch +from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Image, Label + + +class USPS(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "usps", + homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", + valid_options=dict( + split=("train", "test"), + ), + categories=10, + ) + + _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass" + + _RESOURCES = { + "train": HttpResource( + f"{_URL}/usps.bz2", sha256="3771e9dd6ba685185f89867b6e249233dd74652389f263963b3b741e994b034f" + ), + "test": HttpResource( + f"{_URL}/usps.t.bz2", sha256="a9c0164e797d60142a50604917f0baa604f326e9a689698763793fa5d12ffc4e" + ), + } + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [USPS._RESOURCES[config.split]] + + def _prepare_sample(self, line: str) -> Dict[str, Any]: + label, *values = line.strip().split(" ") + values = [float(value.split(":")[1]) for value in values] + pixels = torch.tensor(values).add_(1).div_(2) + return dict( + image=Image(pixels.reshape(16, 16)), + label=Label(int(label) - 1, categories=self.categories), + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + ) -> IterDataPipe[Dict[str, Any]]: + dp = Decompressor(resource_dps[0]) + dp = LineReader(dp, decode=True, return_path=False) + dp = hint_sharding(dp) + dp = hint_shuffling(dp) + return Mapper(dp, self._prepare_sample)