11import csv
22import os
3- import warnings
43from collections import namedtuple
54from typing import Any , Callable , List , Optional , Union , Tuple
65
76import PIL
87import torch
98
10- from .utils import check_integrity , verify_str_arg
9+ from .utils import download_file_from_google_drive , check_integrity , verify_str_arg , extract_archive
1110from .vision import VisionDataset
1211
1312CSV = namedtuple ("CSV" , ["header" , "index" , "data" ])
@@ -36,17 +35,9 @@ class CelebA(VisionDataset):
3635 and returns a transformed version. E.g, ``transforms.PILToTensor``
3736 target_transform (callable, optional): A function/transform that takes in the
3837 target and transforms it.
39- download (bool, optional): Deprecated.
40-
41- .. warning::
42-
43- Downloading CelebA is not supported anymore as of 0.13 and this
44- parameter will be removed in 0.15. See
45- `this issue <https://github.com/pytorch/vision/issues/5705>`__
46- for more details.
47- Please download the files from
48- https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract
49- them in ``root/celeba``.
38+ download (bool, optional): If true, downloads the dataset from the internet and
39+ puts it in root directory. If dataset is already downloaded, it is not
40+ downloaded again.
5041 """
5142
5243 base_folder = "celeba"
@@ -73,7 +64,7 @@ def __init__(
7364 target_type : Union [List [str ], str ] = "attr" ,
7465 transform : Optional [Callable ] = None ,
7566 target_transform : Optional [Callable ] = None ,
76- download : bool = None ,
67+ download : bool = False ,
7768 ) -> None :
7869 super ().__init__ (root , transform = transform , target_transform = target_transform )
7970 self .split = split
@@ -85,15 +76,6 @@ def __init__(
8576 if not self .target_type and self .target_transform is not None :
8677 raise RuntimeError ("target_transform is specified but target_type is empty" )
8778
88- if download is not None :
89- warnings .warn (
90- "Downloading CelebA is not supported anymore as of 0.13, and the "
91- "download parameter will be removed in 0.15. See "
92- "https://github.com/pytorch/vision/issues/5705 for more details. "
93- "Please download the files from "
94- "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
95- "in ``root/celeba``."
96- )
9779 if download :
9880 self .download ()
9981
@@ -164,14 +146,10 @@ def download(self) -> None:
164146 print ("Files already downloaded and verified" )
165147 return
166148
167- raise ValueError (
168- "Downloading CelebA is not supported anymore as of 0.13, and the "
169- "download parameter will be removed in 0.15. See "
170- "https://github.com/pytorch/vision/issues/5705 for more details. "
171- "Please download the files from "
172- "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them "
173- "in ``root/celeba``."
174- )
149+ for (file_id , md5 , filename ) in self .file_list :
150+ download_file_from_google_drive (file_id , os .path .join (self .root , self .base_folder ), filename , md5 )
151+
152+ extract_archive (os .path .join (self .root , self .base_folder , "img_align_celeba.zip" ))
175153
176154 def __getitem__ (self , index : int ) -> Tuple [Any , Any ]:
177155 X = PIL .Image .open (os .path .join (self .root , self .base_folder , "img_align_celeba" , self .filename [index ]))
0 commit comments