From 06c885fcdea9d8602f7ce076b1dfe46e22e336c2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 12:52:20 +0100 Subject: [PATCH] Migrate Semeion prototype dataset --- test/builtin_dataset_mocks.py | 6 +-- .../prototype/datasets/_builtin/semeion.py | 50 +++++++++++-------- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 0210a4dacec..5b17d5fe9e2 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -679,10 +679,10 @@ def sbd(info, root, config): return SBDMockData.generate(root)[config.split] -# @register_mock -def semeion(info, root, config): +@register_mock(configs=[dict()]) +def semeion(root, config): num_samples = 3 - num_categories = len(info.categories) + num_categories = 10 images = torch.rand(num_samples, 256) labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories) diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index fb64c051d6c..e3a802d3cee 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -1,31 +1,43 @@ -from typing import Any, Dict, List, Tuple +import pathlib +from typing import Any, Dict, List, Tuple, Union import torch +from pytest import skip from torchdata.datapipes.iter import ( IterDataPipe, Mapper, CSVParser, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, HttpResource, OnlineResource, ) from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, OneHotLabel +from .._api import register_dataset, register_info + +NAME = "semeion" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=[str(i) for i in range(10)]) -class SEMEION(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "semeion", - categories=10, - homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", - ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class SEMEION(Dataset2): + """Semeion dataset + homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", + """ + + def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: + + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: data = HttpResource( "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data", sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1", @@ -36,18 +48,16 @@ def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]: image_data, label_data = data[:256], data[256:-1] return dict( - image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)), - label=OneHotLabel([int(label) for label in label_data], categories=self.categories), + image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)), + label=OneHotLabel([int(label) for label in label_data], categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = CSVParser(dp, delimiter=" ") dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 1_593