Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .sbd import SBD
from .voc import VOC
184 changes: 184 additions & 0 deletions torchvision/prototype/datasets/_builtin/celeba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import csv
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union

import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
Mapper,
Shuffler,
Filter,
ZipArchiveReader,
Zipper,
)
from torchdata.datapipes.iter import KeyZipper
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
GDriveResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor


class CelebACSVParser(IterDataPipe):
def __init__(
self,
datapipe,
*,
has_header,
):
self.datapipe = datapipe
self.has_header = has_header
self._fmtparams = dict(delimiter=" ", skipinitialspace=True)

def __iter__(self):
for _, file in self.datapipe:
file = (line.decode() for line in file)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine as the whole file can fit into memory. But, to optimize it a little bit, could we change it to a streaming style by adding a decode method to yield each decoded line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that would be a lot better. I'll send a patch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After revisiting this, I think it is already doing what you proposed. Writing

file = (line.decode() for line in file)

is functionally equivalent to

def decode(file):
    for line in file:
        yield line.decode()

file = decode(file)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right!


if self.has_header:
# The first row is skipped, because it only contains the number of samples
next(file)

# Empty field names are filtered out, because some files have an extr white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")
Comment on lines +45 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is super annoying that we can't use DictReader.
Since we have three datapipes with header, could we hard-code the fieldnames for each one and use CSVDictReader(dp, skip_lines=2, fieldnames=[...])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, but there is more to it. Note that we also need to map the output so that it is a tuple with the image id first and the remaining row second. At this point we would probably have a more elaborate implementation bending everything to use the default building blocks than to just write our own.

I agree it is annoying, but writing and understanding this custom parser is not hard so I feel its warranted.


for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
yield line.pop("image_id"), line
else:
for line in csv.reader(file, **self._fmtparams):
yield line[0], line[1:]


class CelebA(Dataset):
@property
def info(self) -> DatasetInfo:
return DatasetInfo(
"celeba",
type=DatasetType.IMAGE,
homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
)

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
splits = GDriveResource(
"0B7EVK8r0v71pY0NSMzRuSXJEVkk",
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
file_name="list_eval_partition.txt",
)
images = GDriveResource(
"0B7EVK8r0v71pZjFTYXZWM3FlRnM",
sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
file_name="img_align_celeba.zip",
)
identities = GDriveResource(
"1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
file_name="identity_CelebA.txt",
)
attributes = GDriveResource(
"0B7EVK8r0v71pblRyaVFSWGxPY0U",
sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
file_name="list_attr_celeba.txt",
)
bboxes = GDriveResource(
"0B7EVK8r0v71pbThiMVRxWXZ4dU0",
sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
file_name="list_bbox_celeba.txt",
)
landmarks = GDriveResource(
"0B7EVK8r0v71pd0FJY3Blby1HUTQ",
sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
file_name="list_landmarks_align_celeba.txt",
)
return [splits, images, identities, attributes, bboxes, landmarks]

_SPLIT_ID_TO_NAME = {
"0": "train",
"1": "valid",
"2": "test",
}

def _filter_split(self, data: Tuple[str, str], *, split):
_, split_id = data
return self._SPLIT_ID_TO_NAME[split_id[0]] == split

def _collate_anns(
self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...]
) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]:
(image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)

def _collate_and_decode_sample(
self,
data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, io.IOBase]], Tuple[str, Dict[str, Any]]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, _, image_data = split_and_image_data
path, buffer = image_data
_, ann = ann_data

image = decoder(buffer) if decoder else buffer

identity = torch.tensor(int(ann["identity"][0]))
attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
landmarks = {
landmark: torch.tensor((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
for landmark in {key[:-2] for key in ann["landmarks"].keys()}
}

return dict(
path=path,
image=image,
identity=identity,
attributes=attributes,
bbox=bbox,
landmarks=landmarks,
)

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

splits_dp = CelebACSVParser(splits_dp, has_header=False)
splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

images_dp = ZipArchiveReader(images_dp)

anns_dp: IterDataPipe = Zipper(
*[
CelebACSVParser(dp, has_header=has_header)
for dp, has_header in (
(identities_dp, False),
(attributes_dp, True),
(bboxes_dp, True),
(landmarks_dp, True),
)
]
)
anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)

dp = KeyZipper(
splits_dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
dp = KeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))