test/builtin_dataset_mocks.py (12 changes: 6 additions & 6 deletions)
@@ -334,8 +334,8 @@ def generate(
         make_tar(root, name, folder, compression="gz")


-# @register_mock
-def cifar10(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def cifar10(root, config):
     train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
     test_files = ["test_batch"]

@@ -349,11 +349,11 @@ def cifar10(info, root, config):
         labels_key="labels",
     )

-    return len(train_files if config.split == "train" else test_files)
+    return len(train_files if config["split"] == "train" else test_files)


-# @register_mock
-def cifar100(info, root, config):
+@register_mock(configs=combinations_grid(split=("train", "test")))
+def cifar100(root, config):
     train_files = ["train"]
     test_files = ["test"]

@@ -367,7 +367,7 @@ def cifar100(info, root, config):
         labels_key="fine_labels",
     )

-    return len(train_files if config.split == "train" else test_files)
+    return len(train_files if config["split"] == "train" else test_files)


 # @register_mock
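The mocks now declare their own configurations via register_mock(configs=...) and receive each configuration as a plain dict, which is why the bodies index config["split"] instead of reading config.split. As a rough sketch (an assumption about the test helper's behavior, not its actual definition), combinations_grid expands keyword arguments into the Cartesian product of their values, one dict per combination:

import itertools

# Hedged sketch of the helper used in the decorator above: each keyword maps to a
# tuple of candidate values, and the result is one plain dict per combination.
def combinations_grid(**kwargs):
    return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

print(combinations_grid(split=("train", "test")))
# [{'split': 'train'}, {'split': 'test'}]

Under that assumption, each mock runs once per dict, returning the number of batch files for whichever split the current configuration selects.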
torchvision/prototype/datasets/_builtin/cifar.py (99 changes: 60 additions & 39 deletions)
@@ -1,30 +1,21 @@
 import abc
 import functools
 import io
 import pathlib
 import pickle
-from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO
+from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO, Union

 import numpy as np
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Filter,
     Mapper,
 )
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
-    HttpResource,
-    OnlineResource,
-)
-from torchvision.prototype.datasets.utils._internal import (
-    hint_shuffling,
-    path_comparator,
-    hint_sharding,
-)
+from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils._internal import hint_shuffling, path_comparator, hint_sharding, BUILTIN_DIR
 from torchvision.prototype.features import Label, Image

+from .._api import register_dataset, register_info


 class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
     def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
@@ -38,25 +29,29 @@ def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
         yield from iter(zip(image_arrays, category_idcs))


-class _CifarBase(Dataset):
+class _CifarBase(Dataset2):
     _FILE_NAME: str
     _SHA256: str
     _LABELS_KEY: str
     _META_FILE_NAME: str
     _CATEGORIES_KEY: str
+    # _categories: List[str]

+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "test"))
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     @abc.abstractmethod
-    def _is_data_file(self, data: Tuple[str, BinaryIO], *, split: str) -> Optional[int]:
+    def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
         pass

-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            type(self).__name__.lower(),
-            homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
-            valid_options=dict(split=("train", "test")),
-        )

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         return [
             HttpResource(
                 f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
@@ -72,52 +67,78 @@ def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data
         return dict(
             image=Image(image_array),
-            label=Label(category_idx, categories=self.categories),
+            label=Label(category_idx, categories=self._categories),
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = Filter(dp, functools.partial(self._is_data_file, split=config.split))
+        dp = Filter(dp, self._is_data_file)
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        resources = self.resources(self.default_config)
+    def __len__(self) -> int:
+        return 50_000 if self._split == "train" else 10_000

+    def _generate_categories(self) -> List[str]:
+        resources = self._resources()

-        dp = resources[0].load(root)
+        dp = resources[0].load(self._root)
         dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
         dp = Mapper(dp, self._unpickle)

         return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])


+CIFAR10_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar10.categories"))


+@register_info("cifar10")
+def _cifar10_info() -> Dict[str, Any]:
+    return dict(categories=CIFAR10_CATEGORIES)


+@register_dataset("cifar10")
 class Cifar10(_CifarBase):
+    """
+    - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
+    """

     _FILE_NAME = "cifar-10-python.tar.gz"
     _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
     _LABELS_KEY = "labels"
     _META_FILE_NAME = "batches.meta"
     _CATEGORIES_KEY = "label_names"
+    _categories = _cifar10_info()["categories"]

-    def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
+    def _is_data_file(self, data: Tuple[str, Any]) -> bool:
         path = pathlib.Path(data[0])
-        return path.name.startswith("data" if split == "train" else "test")
+        return path.name.startswith("data" if self._split == "train" else "test")


+CIFAR100_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar100.categories"))

@register_info("cifar100")
def _cifar100_info() -> Dict[str, Any]:
return dict(categories=CIFAR10_CATEGORIES)


@register_dataset("cifar100")
class Cifar100(_CifarBase):
"""
- **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
"""

_FILE_NAME = "cifar-100-python.tar.gz"
_SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
_LABELS_KEY = "fine_labels"
_META_FILE_NAME = "meta"
_CATEGORIES_KEY = "fine_label_names"
_categories = _cifar100_info()["categories"]

def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
def _is_data_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name == split
return path.name == self._split
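Taken together, the rewritten module replaces the config-driven Dataset/DatasetInfo pair with self-contained Dataset2 classes registered by name. A rough usage sketch based only on the constructor, __len__, and sample dict visible in this diff (the root path is a placeholder and the prototype API is unstable, so details may differ in other revisions):

from torchvision.prototype.datasets._builtin.cifar import Cifar10

# "path/to/data" is a placeholder root; the archive is fetched there if missing.
dataset = Cifar10("path/to/data", split="test", skip_integrity_check=True)
print(len(dataset))  # 10_000 for the test split, per __len__ above

sample = next(iter(dataset))  # Dataset2 behaves as an iterable datapipe of sample dicts
image, label = sample["image"], sample["label"]
print(image.shape, label.categories[int(label)])

Because the classes carry @register_dataset("cifar10") and @register_dataset("cifar100"), they should presumably also be reachable through the prototype datasets registry (a load("cifar10", split="test") style entry point) rather than by importing the private module directly.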